From 35152593ed416ff6fdb80f5fafb16513b3cd851a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Apr 2024 14:20:45 +0000 Subject: [PATCH] Bump github.com/quic-go/quic-go from 0.38.1 to 0.42.0 Bumps [github.com/quic-go/quic-go](https://github.com/quic-go/quic-go) from 0.38.1 to 0.42.0. - [Release notes](https://github.com/quic-go/quic-go/releases) - [Changelog](https://github.com/quic-go/quic-go/blob/master/Changelog.md) - [Commits](https://github.com/quic-go/quic-go/compare/v0.38.1...v0.42.0) --- updated-dependencies: - dependency-name: github.com/quic-go/quic-go dependency-type: indirect ... Signed-off-by: dependabot[bot] --- go.mod | 5 +- go.sum | 18 +- vendor/github.com/golang/mock/CONTRIBUTORS | 37 - .../golang/mock/mockgen/version.1.11.go | 26 - vendor/github.com/quic-go/qtls-go1-20/LICENSE | 27 - .../github.com/quic-go/qtls-go1-20/README.md | 6 - .../github.com/quic-go/qtls-go1-20/alert.go | 109 - vendor/github.com/quic-go/qtls-go1-20/auth.go | 293 --- .../github.com/quic-go/qtls-go1-20/cache.go | 95 - .../quic-go/qtls-go1-20/cipher_suites.go | 691 ------ .../github.com/quic-go/qtls-go1-20/common.go | 1454 ------------- vendor/github.com/quic-go/qtls-go1-20/conn.go | 1643 -------------- .../quic-go/qtls-go1-20/handshake_client.go | 1130 ---------- .../qtls-go1-20/handshake_client_tls13.go | 782 ------- .../quic-go/qtls-go1-20/handshake_messages.go | 1886 ----------------- .../quic-go/qtls-go1-20/handshake_server.go | 899 -------- .../qtls-go1-20/handshake_server_tls13.go | 979 --------- .../quic-go/qtls-go1-20/key_agreement.go | 366 ---- .../quic-go/qtls-go1-20/key_schedule.go | 159 -- .../quic-go/qtls-go1-20/notboring.go | 18 - vendor/github.com/quic-go/qtls-go1-20/prf.go | 283 --- vendor/github.com/quic-go/qtls-go1-20/quic.go | 412 ---- .../github.com/quic-go/qtls-go1-20/ticket.go | 203 -- vendor/github.com/quic-go/qtls-go1-20/tls.go | 356 ---- .../github.com/quic-go/qtls-go1-20/unsafe.go | 101 - .../github.com/quic-go/quic-go/.golangci.yml | 11 - vendor/github.com/quic-go/quic-go/README.md | 46 +- vendor/github.com/quic-go/quic-go/client.go | 10 +- .../github.com/quic-go/quic-go/closed_conn.go | 33 +- vendor/github.com/quic-go/quic-go/codecov.yml | 11 +- vendor/github.com/quic-go/quic-go/config.go | 62 +- .../quic-go/quic-go/conn_id_generator.go | 11 +- .../quic-go/quic-go/conn_id_manager.go | 2 +- .../github.com/quic-go/quic-go/connection.go | 315 +-- .../quic-go/quic-go/crypto_stream.go | 30 +- .../quic-go/quic-go/datagram_queue.go | 87 +- vendor/github.com/quic-go/quic-go/errors.go | 12 + vendor/github.com/quic-go/quic-go/framer.go | 69 +- .../github.com/quic-go/quic-go/interface.go | 60 +- .../quic-go/internal/ackhandler/ackhandler.go | 7 +- .../quic-go/internal/ackhandler/ecn.go | 296 +++ .../quic-go/internal/ackhandler/interfaces.go | 9 +- .../quic-go/internal/ackhandler/mockgen.go | 5 +- .../ackhandler/packet_number_generator.go | 2 +- .../ackhandler/received_packet_handler.go | 49 +- .../ackhandler/received_packet_tracker.go | 226 +- .../ackhandler/sent_packet_handler.go | 105 +- .../quic-go/internal/congestion/cubic.go | 3 +- .../internal/congestion/cubic_sender.go | 16 +- .../internal/congestion/hybrid_slow_start.go | 5 +- .../quic-go/internal/congestion/interface.go | 2 +- .../quic-go/internal/congestion/pacer.go | 24 +- .../flowcontrol/base_flow_controller.go | 6 +- .../flowcontrol/connection_flow_controller.go | 2 +- .../quic-go/internal/flowcontrol/interface.go | 7 +- .../flowcontrol/stream_flow_controller.go | 2 +- .../quic-go/internal/handshake/aead.go | 42 +- .../internal/handshake/cipher_suite.go | 6 +- .../internal/handshake/crypto_setup.go | 215 +- .../internal/handshake/header_protector.go | 9 +- .../quic-go/internal/handshake/hkdf.go | 2 +- .../internal/handshake/initial_aead.go | 8 +- .../quic-go/internal/handshake/retry.go | 2 +- .../internal/handshake/session_ticket.go | 19 +- .../internal/handshake/token_generator.go | 13 +- .../internal/handshake/token_protector.go | 33 +- .../internal/handshake/updatable_aead.go | 18 +- .../quic-go/internal/protocol/params.go | 13 - .../quic-go/internal/protocol/perspective.go | 4 +- .../quic-go/internal/protocol/protocol.go | 42 +- .../quic-go/internal/protocol/version.go | 59 +- .../quic-go/internal/qerr/error_codes.go | 5 +- .../quic-go/quic-go/internal/qerr/errors.go | 4 +- ...{cipher_suite_go121.go => cipher_suite.go} | 14 - .../internal/qtls/client_session_cache.go | 25 +- .../quic-go/quic-go/internal/qtls/go120.go | 145 -- .../quic-go/internal/qtls/go_oldversion.go | 5 - .../internal/qtls/{go121.go => qtls.go} | 69 +- .../quic-go/quic-go/internal/utils/minmax.go | 38 +- .../internal/utils/ringbuffer/ringbuffer.go | 12 +- .../quic-go/internal/utils/rtt_stats.go | 8 +- .../quic-go/internal/wire/ack_frame.go | 10 +- .../internal/wire/connection_close_frame.go | 6 +- .../quic-go/internal/wire/crypto_frame.go | 8 +- .../internal/wire/data_blocked_frame.go | 6 +- .../quic-go/internal/wire/datagram_frame.go | 14 +- .../quic-go/internal/wire/extended_header.go | 8 +- .../quic-go/internal/wire/frame_parser.go | 21 +- .../internal/wire/handshake_done_frame.go | 4 +- .../quic-go/quic-go/internal/wire/header.go | 12 +- .../quic-go/internal/wire/interface.go | 10 +- .../quic-go/quic-go/internal/wire/log.go | 4 +- .../quic-go/internal/wire/max_data_frame.go | 6 +- .../internal/wire/max_stream_data_frame.go | 6 +- .../internal/wire/max_streams_frame.go | 6 +- .../internal/wire/new_connection_id_frame.go | 10 +- .../quic-go/internal/wire/new_token_frame.go | 6 +- .../internal/wire/path_challenge_frame.go | 6 +- .../internal/wire/path_response_frame.go | 6 +- .../quic-go/internal/wire/ping_frame.go | 4 +- .../internal/wire/reset_stream_frame.go | 6 +- .../wire/retire_connection_id_frame.go | 6 +- .../internal/wire/stop_sending_frame.go | 6 +- .../wire/stream_data_blocked_frame.go | 6 +- .../quic-go/internal/wire/stream_frame.go | 10 +- .../internal/wire/streams_blocked_frame.go | 6 +- .../internal/wire/transport_parameters.go | 36 +- .../internal/wire/version_negotiation.go | 31 +- .../quic-go/logging/connection_tracer.go | 263 +++ .../quic-go/quic-go/logging/interface.go | 60 +- .../quic-go/quic-go/logging/mockgen.go | 4 - .../quic-go/quic-go/logging/multiplex.go | 226 -- .../quic-go/quic-go/logging/null_tracer.go | 58 - .../quic-go/quic-go/logging/tracer.go | 59 + .../quic-go/quic-go/logging/types.go | 34 + vendor/github.com/quic-go/quic-go/mockgen.go | 51 +- vendor/github.com/quic-go/quic-go/oss-fuzz.sh | 6 +- .../quic-go/quic-go/packet_handler_map.go | 48 +- .../quic-go/quic-go/packet_packer.go | 96 +- .../quic-go/quic-go/packet_unpacker.go | 8 +- .../quic-go/quic-go/receive_stream.go | 4 - .../quic-go/quic-go/retransmission_queue.go | 6 +- .../github.com/quic-go/quic-go/send_conn.go | 56 +- .../github.com/quic-go/quic-go/send_queue.go | 13 +- .../github.com/quic-go/quic-go/send_stream.go | 22 +- vendor/github.com/quic-go/quic-go/server.go | 476 +++-- vendor/github.com/quic-go/quic-go/stream.go | 2 +- vendor/github.com/quic-go/quic-go/sys_conn.go | 8 +- .../quic-go/quic-go/sys_conn_helper_darwin.go | 2 + .../quic-go/sys_conn_helper_freebsd.go | 2 + .../quic-go/quic-go/sys_conn_helper_linux.go | 13 + .../quic-go/sys_conn_helper_nonlinux.go | 1 + .../quic-go/quic-go/sys_conn_oob.go | 66 +- .../github.com/quic-go/quic-go/token_store.go | 5 +- vendor/github.com/quic-go/quic-go/tools.go | 2 +- .../github.com/quic-go/quic-go/transport.go | 131 +- .../golang => go.uber.org}/mock/AUTHORS | 0 .../golang => go.uber.org}/mock/LICENSE | 0 .../go.uber.org/mock/mockgen/generic_go118.go | 130 ++ .../mock/mockgen/generic_notgo118.go | 41 + .../mock/mockgen/mockgen.go | 318 ++- .../mock/mockgen/model/model.go | 70 +- .../mock/mockgen/parse.go | 349 ++- .../mock/mockgen/reflect.go | 16 +- .../mock/mockgen/version.go} | 6 +- vendor/golang.org/x/crypto/cryptobyte/asn1.go | 825 ------- .../x/crypto/cryptobyte/asn1/asn1.go | 46 - .../golang.org/x/crypto/cryptobyte/builder.go | 350 --- .../golang.org/x/crypto/cryptobyte/string.go | 183 -- .../x/exp/constraints/constraints.go | 50 - vendor/golang.org/x/exp/rand/exp.go | 221 ++ vendor/golang.org/x/exp/rand/normal.go | 156 ++ vendor/golang.org/x/exp/rand/rand.go | 372 ++++ vendor/golang.org/x/exp/rand/rng.go | 91 + vendor/golang.org/x/exp/rand/zipf.go | 77 + vendor/modules.txt | 19 +- 156 files changed, 4016 insertions(+), 15599 deletions(-) delete mode 100644 vendor/github.com/golang/mock/CONTRIBUTORS delete mode 100644 vendor/github.com/golang/mock/mockgen/version.1.11.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/LICENSE delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/README.md delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/alert.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/auth.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/cache.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/cipher_suites.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/common.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/conn.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/handshake_client.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/handshake_client_tls13.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/handshake_messages.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/handshake_server.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/handshake_server_tls13.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/key_agreement.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/key_schedule.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/notboring.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/prf.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/quic.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/ticket.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/tls.go delete mode 100644 vendor/github.com/quic-go/qtls-go1-20/unsafe.go create mode 100644 vendor/github.com/quic-go/quic-go/internal/ackhandler/ecn.go rename vendor/github.com/quic-go/quic-go/internal/qtls/{cipher_suite_go121.go => cipher_suite.go} (85%) delete mode 100644 vendor/github.com/quic-go/quic-go/internal/qtls/go120.go delete mode 100644 vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go rename vendor/github.com/quic-go/quic-go/internal/qtls/{go121.go => qtls.go} (58%) create mode 100644 vendor/github.com/quic-go/quic-go/logging/connection_tracer.go delete mode 100644 vendor/github.com/quic-go/quic-go/logging/mockgen.go delete mode 100644 vendor/github.com/quic-go/quic-go/logging/multiplex.go delete mode 100644 vendor/github.com/quic-go/quic-go/logging/null_tracer.go create mode 100644 vendor/github.com/quic-go/quic-go/logging/tracer.go rename vendor/{github.com/golang => go.uber.org}/mock/AUTHORS (100%) rename vendor/{github.com/golang => go.uber.org}/mock/LICENSE (100%) create mode 100644 vendor/go.uber.org/mock/mockgen/generic_go118.go create mode 100644 vendor/go.uber.org/mock/mockgen/generic_notgo118.go rename vendor/{github.com/golang => go.uber.org}/mock/mockgen/mockgen.go (62%) rename vendor/{github.com/golang => go.uber.org}/mock/mockgen/model/model.go (87%) rename vendor/{github.com/golang => go.uber.org}/mock/mockgen/parse.go (63%) rename vendor/{github.com/golang => go.uber.org}/mock/mockgen/reflect.go (93%) rename vendor/{github.com/golang/mock/mockgen/version.1.12.go => go.uber.org/mock/mockgen/version.go} (94%) delete mode 100644 vendor/golang.org/x/crypto/cryptobyte/asn1.go delete mode 100644 vendor/golang.org/x/crypto/cryptobyte/asn1/asn1.go delete mode 100644 vendor/golang.org/x/crypto/cryptobyte/builder.go delete mode 100644 vendor/golang.org/x/crypto/cryptobyte/string.go delete mode 100644 vendor/golang.org/x/exp/constraints/constraints.go create mode 100644 vendor/golang.org/x/exp/rand/exp.go create mode 100644 vendor/golang.org/x/exp/rand/normal.go create mode 100644 vendor/golang.org/x/exp/rand/rand.go create mode 100644 vendor/golang.org/x/exp/rand/rng.go create mode 100644 vendor/golang.org/x/exp/rand/zipf.go diff --git a/go.mod b/go.mod index 9cf925495..cd01cdf3c 100644 --- a/go.mod +++ b/go.mod @@ -88,7 +88,6 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.3 // indirect github.com/google/flatbuffers v1.12.1 // indirect @@ -127,8 +126,7 @@ require ( github.com/pires/go-proxyproto v0.6.2 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/quic-go/qtls-go1-20 v0.3.3 // indirect - github.com/quic-go/quic-go v0.38.1 // indirect + github.com/quic-go/quic-go v0.42.0 // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/rs/cors v1.8.2 // indirect github.com/skycoin/noise v0.0.0-20180327030543-2492fe189ae6 // indirect @@ -143,6 +141,7 @@ require ( github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/xtaci/lossyconn v0.0.0-20200209145036-adba10fffc37 // indirect go.opencensus.io v0.23.0 // indirect + go.uber.org/mock v0.4.0 // indirect golang.org/x/arch v0.4.0 // indirect golang.org/x/crypto v0.17.0 // indirect golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect diff --git a/go.sum b/go.sum index 519b3dc18..ac4fb0ec7 100644 --- a/go.sum +++ b/go.sum @@ -179,8 +179,6 @@ github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4er github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -401,10 +399,8 @@ github.com/pterm/pterm v0.12.36/go.mod h1:NjiL09hFhT/vWjQHSj1athJpx6H8cjpHXNAK5b github.com/pterm/pterm v0.12.40/go.mod h1:ffwPLwlbXxP+rxT0GsgDTzS3y3rmpAO1NMjUkGTYf8s= github.com/pterm/pterm v0.12.66 h1:bjsoMyUstaarzJ1NG7+1HGT7afR0JVMYsR3ooPeh4bo= github.com/pterm/pterm v0.12.66/go.mod h1:nFuT9ZVkkCi8o4L1dtWuYPwDQxggLh4C263qG5nTLpQ= -github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM= -github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= -github.com/quic-go/quic-go v0.38.1 h1:M36YWA5dEhEeT+slOu/SwMEucbYd0YFidxG3KlGPZaE= -github.com/quic-go/quic-go v0.38.1/go.mod h1:ijnZM7JsFIkp4cRyjxJNIzdSfCLmUMg9wdyhGmg+SN4= +github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM= +github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -514,7 +510,6 @@ github.com/xtaci/lossyconn v0.0.0-20200209145036-adba10fffc37 h1:EWU6Pktpas0n8lL github.com/xtaci/lossyconn v0.0.0-20200209145036-adba10fffc37/go.mod h1:HpMP7DB2CyokmAh4lp0EQnnWhmycP/TvwBGzvuie+H0= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zcalusic/sysinfo v1.0.1 h1:cVh8q3codjh43AGRTa54dJ2Zq+qPejv8n2VWpxKViwc= github.com/zcalusic/sysinfo v1.0.1/go.mod h1:LxwKwtQdbTIQc65drhjQzYzt0o7jfB80LrrZm7SWn8o= @@ -528,6 +523,8 @@ go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= @@ -563,7 +560,6 @@ golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKG golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -581,7 +577,6 @@ golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= @@ -594,7 +589,6 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -619,7 +613,6 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -657,6 +650,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -671,7 +666,6 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= diff --git a/vendor/github.com/golang/mock/CONTRIBUTORS b/vendor/github.com/golang/mock/CONTRIBUTORS deleted file mode 100644 index def849cab..000000000 --- a/vendor/github.com/golang/mock/CONTRIBUTORS +++ /dev/null @@ -1,37 +0,0 @@ -# This is the official list of people who can contribute (and typically -# have contributed) code to the gomock repository. -# The AUTHORS file lists the copyright holders; this file -# lists people. For example, Google employees are listed here -# but not in AUTHORS, because Google holds the copyright. -# -# The submission process automatically checks to make sure -# that people submitting code are listed in this file (by email address). -# -# Names should be added to this file only after verifying that -# the individual or the individual's organization has agreed to -# the appropriate Contributor License Agreement, found here: -# -# http://code.google.com/legal/individual-cla-v1.0.html -# http://code.google.com/legal/corporate-cla-v1.0.html -# -# The agreement for individuals can be filled out on the web. -# -# When adding J Random Contributor's name to this file, -# either J's name or J's organization's name should be -# added to the AUTHORS file, depending on whether the -# individual or corporate CLA was used. - -# Names should be added to this file like so: -# Name -# -# An entry with two email addresses specifies that the -# first address should be used in the submit logs and -# that the second address should be recognized as the -# same person when interacting with Rietveld. - -# Please keep the list sorted. - -Aaron Jacobs -Alex Reece -David Symonds -Ryan Barrett diff --git a/vendor/github.com/golang/mock/mockgen/version.1.11.go b/vendor/github.com/golang/mock/mockgen/version.1.11.go deleted file mode 100644 index e6b25db23..000000000 --- a/vendor/github.com/golang/mock/mockgen/version.1.11.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build !go1.12 - -package main - -import ( - "log" -) - -func printModuleVersion() { - log.Printf("No version information is available for Mockgen compiled with " + - "version 1.11") -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/LICENSE b/vendor/github.com/quic-go/qtls-go1-20/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/quic-go/qtls-go1-20/README.md b/vendor/github.com/quic-go/qtls-go1-20/README.md deleted file mode 100644 index 2beaa2f23..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# qtls - -[![Go Reference](https://pkg.go.dev/badge/github.com/quic-go/qtls-go1-20.svg)](https://pkg.go.dev/github.com/quic-go/qtls-go1-20) -[![.github/workflows/go-test.yml](https://github.com/quic-go/qtls-go1-20/actions/workflows/go-test.yml/badge.svg)](https://github.com/quic-go/qtls-go1-20/actions/workflows/go-test.yml) - -This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/quic-go/quic-go). diff --git a/vendor/github.com/quic-go/qtls-go1-20/alert.go b/vendor/github.com/quic-go/qtls-go1-20/alert.go deleted file mode 100644 index 687ada843..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/alert.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import "strconv" - -// An AlertError is a TLS alert. -// -// When using a QUIC transport, QUICConn methods will return an error -// which wraps AlertError rather than sending a TLS alert. -type AlertError uint8 - -func (e AlertError) Error() string { - return alert(e).String() -} - -type alert uint8 - -const ( - // alert level - alertLevelWarning = 1 - alertLevelError = 2 -) - -const ( - alertCloseNotify alert = 0 - alertUnexpectedMessage alert = 10 - alertBadRecordMAC alert = 20 - alertDecryptionFailed alert = 21 - alertRecordOverflow alert = 22 - alertDecompressionFailure alert = 30 - alertHandshakeFailure alert = 40 - alertBadCertificate alert = 42 - alertUnsupportedCertificate alert = 43 - alertCertificateRevoked alert = 44 - alertCertificateExpired alert = 45 - alertCertificateUnknown alert = 46 - alertIllegalParameter alert = 47 - alertUnknownCA alert = 48 - alertAccessDenied alert = 49 - alertDecodeError alert = 50 - alertDecryptError alert = 51 - alertExportRestriction alert = 60 - alertProtocolVersion alert = 70 - alertInsufficientSecurity alert = 71 - alertInternalError alert = 80 - alertInappropriateFallback alert = 86 - alertUserCanceled alert = 90 - alertNoRenegotiation alert = 100 - alertMissingExtension alert = 109 - alertUnsupportedExtension alert = 110 - alertCertificateUnobtainable alert = 111 - alertUnrecognizedName alert = 112 - alertBadCertificateStatusResponse alert = 113 - alertBadCertificateHashValue alert = 114 - alertUnknownPSKIdentity alert = 115 - alertCertificateRequired alert = 116 - alertNoApplicationProtocol alert = 120 -) - -var alertText = map[alert]string{ - alertCloseNotify: "close notify", - alertUnexpectedMessage: "unexpected message", - alertBadRecordMAC: "bad record MAC", - alertDecryptionFailed: "decryption failed", - alertRecordOverflow: "record overflow", - alertDecompressionFailure: "decompression failure", - alertHandshakeFailure: "handshake failure", - alertBadCertificate: "bad certificate", - alertUnsupportedCertificate: "unsupported certificate", - alertCertificateRevoked: "revoked certificate", - alertCertificateExpired: "expired certificate", - alertCertificateUnknown: "unknown certificate", - alertIllegalParameter: "illegal parameter", - alertUnknownCA: "unknown certificate authority", - alertAccessDenied: "access denied", - alertDecodeError: "error decoding message", - alertDecryptError: "error decrypting message", - alertExportRestriction: "export restriction", - alertProtocolVersion: "protocol version not supported", - alertInsufficientSecurity: "insufficient security level", - alertInternalError: "internal error", - alertInappropriateFallback: "inappropriate fallback", - alertUserCanceled: "user canceled", - alertNoRenegotiation: "no renegotiation", - alertMissingExtension: "missing extension", - alertUnsupportedExtension: "unsupported extension", - alertCertificateUnobtainable: "certificate unobtainable", - alertUnrecognizedName: "unrecognized name", - alertBadCertificateStatusResponse: "bad certificate status response", - alertBadCertificateHashValue: "bad certificate hash value", - alertUnknownPSKIdentity: "unknown PSK identity", - alertCertificateRequired: "certificate required", - alertNoApplicationProtocol: "no application protocol", -} - -func (e alert) String() string { - s, ok := alertText[e] - if ok { - return "tls: " + s - } - return "tls: alert(" + strconv.Itoa(int(e)) + ")" -} - -func (e alert) Error() string { - return e.String() -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/auth.go b/vendor/github.com/quic-go/qtls-go1-20/auth.go deleted file mode 100644 index effc9aced..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/auth.go +++ /dev/null @@ -1,293 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rsa" - "errors" - "fmt" - "hash" - "io" -) - -// verifyHandshakeSignature verifies a signature against pre-hashed -// (if required) handshake contents. -func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { - switch sigType { - case signatureECDSA: - pubKey, ok := pubkey.(*ecdsa.PublicKey) - if !ok { - return fmt.Errorf("expected an ECDSA public key, got %T", pubkey) - } - if !ecdsa.VerifyASN1(pubKey, signed, sig) { - return errors.New("ECDSA verification failure") - } - case signatureEd25519: - pubKey, ok := pubkey.(ed25519.PublicKey) - if !ok { - return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey) - } - if !ed25519.Verify(pubKey, signed, sig) { - return errors.New("Ed25519 verification failure") - } - case signaturePKCS1v15: - pubKey, ok := pubkey.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("expected an RSA public key, got %T", pubkey) - } - if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil { - return err - } - case signatureRSAPSS: - pubKey, ok := pubkey.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("expected an RSA public key, got %T", pubkey) - } - signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} - if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil { - return err - } - default: - return errors.New("internal error: unknown signature type") - } - return nil -} - -const ( - serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" - clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" -) - -var signaturePadding = []byte{ - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, -} - -// signedMessage returns the pre-hashed (if necessary) message to be signed by -// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. -func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { - if sigHash == directSigning { - b := &bytes.Buffer{} - b.Write(signaturePadding) - io.WriteString(b, context) - b.Write(transcript.Sum(nil)) - return b.Bytes() - } - h := sigHash.New() - h.Write(signaturePadding) - io.WriteString(h, context) - h.Write(transcript.Sum(nil)) - return h.Sum(nil) -} - -// typeAndHashFromSignatureScheme returns the corresponding signature type and -// crypto.Hash for a given TLS SignatureScheme. -func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) { - switch signatureAlgorithm { - case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512: - sigType = signaturePKCS1v15 - case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512: - sigType = signatureRSAPSS - case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512: - sigType = signatureECDSA - case Ed25519: - sigType = signatureEd25519 - default: - return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) - } - switch signatureAlgorithm { - case PKCS1WithSHA1, ECDSAWithSHA1: - hash = crypto.SHA1 - case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256: - hash = crypto.SHA256 - case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384: - hash = crypto.SHA384 - case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512: - hash = crypto.SHA512 - case Ed25519: - hash = directSigning - default: - return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) - } - return sigType, hash, nil -} - -// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for -// a given public key used with TLS 1.0 and 1.1, before the introduction of -// signature algorithm negotiation. -func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) { - switch pub.(type) { - case *rsa.PublicKey: - return signaturePKCS1v15, crypto.MD5SHA1, nil - case *ecdsa.PublicKey: - return signatureECDSA, crypto.SHA1, nil - case ed25519.PublicKey: - // RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1, - // but it requires holding on to a handshake transcript to do a - // full signature, and not even OpenSSL bothers with the - // complexity, so we can't even test it properly. - return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2") - default: - return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub) - } -} - -var rsaSignatureSchemes = []struct { - scheme SignatureScheme - minModulusBytes int - maxVersion uint16 -}{ - // RSA-PSS is used with PSSSaltLengthEqualsHash, and requires - // emLen >= hLen + sLen + 2 - {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, - {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, - {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, - // PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires - // emLen >= len(prefix) + hLen + 11 - // TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS. - {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, - {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, - {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, - {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12}, -} - -// signatureSchemesForCertificate returns the list of supported SignatureSchemes -// for a given certificate, based on the public key and the protocol version, -// and optionally filtered by its explicit SupportedSignatureAlgorithms. -// -// This function must be kept in sync with supportedSignatureAlgorithms. -// FIPS filtering is applied in the caller, selectSignatureScheme. -func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { - priv, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return nil - } - - var sigAlgs []SignatureScheme - switch pub := priv.Public().(type) { - case *ecdsa.PublicKey: - if version != VersionTLS13 { - // In TLS 1.2 and earlier, ECDSA algorithms are not - // constrained to a single curve. - sigAlgs = []SignatureScheme{ - ECDSAWithP256AndSHA256, - ECDSAWithP384AndSHA384, - ECDSAWithP521AndSHA512, - ECDSAWithSHA1, - } - break - } - switch pub.Curve { - case elliptic.P256(): - sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256} - case elliptic.P384(): - sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384} - case elliptic.P521(): - sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512} - default: - return nil - } - case *rsa.PublicKey: - size := pub.Size() - sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes)) - for _, candidate := range rsaSignatureSchemes { - if size >= candidate.minModulusBytes && version <= candidate.maxVersion { - sigAlgs = append(sigAlgs, candidate.scheme) - } - } - case ed25519.PublicKey: - sigAlgs = []SignatureScheme{Ed25519} - default: - return nil - } - - if cert.SupportedSignatureAlgorithms != nil { - var filteredSigAlgs []SignatureScheme - for _, sigAlg := range sigAlgs { - if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) { - filteredSigAlgs = append(filteredSigAlgs, sigAlg) - } - } - return filteredSigAlgs - } - return sigAlgs -} - -// selectSignatureScheme picks a SignatureScheme from the peer's preference list -// that works with the selected certificate. It's only called for protocol -// versions that support signature algorithms, so TLS 1.2 and 1.3. -func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) { - supportedAlgs := signatureSchemesForCertificate(vers, c) - if len(supportedAlgs) == 0 { - return 0, unsupportedCertificateError(c) - } - if len(peerAlgs) == 0 && vers == VersionTLS12 { - // For TLS 1.2, if the client didn't send signature_algorithms then we - // can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1. - peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1} - } - // Pick signature scheme in the peer's preference order, as our - // preference order is not configurable. - for _, preferredAlg := range peerAlgs { - if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, fipsSupportedSignatureAlgorithms) { - continue - } - if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { - return preferredAlg, nil - } - } - return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms") -} - -// unsupportedCertificateError returns a helpful error for certificates with -// an unsupported private key. -func unsupportedCertificateError(cert *Certificate) error { - switch cert.PrivateKey.(type) { - case rsa.PrivateKey, ecdsa.PrivateKey: - return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T", - cert.PrivateKey, cert.PrivateKey) - case *ed25519.PrivateKey: - return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey") - } - - signer, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer", - cert.PrivateKey) - } - - switch pub := signer.Public().(type) { - case *ecdsa.PublicKey: - switch pub.Curve { - case elliptic.P256(): - case elliptic.P384(): - case elliptic.P521(): - default: - return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name) - } - case *rsa.PublicKey: - return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms") - case ed25519.PublicKey: - default: - return fmt.Errorf("tls: unsupported certificate key (%T)", pub) - } - - if cert.SupportedSignatureAlgorithms != nil { - return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms") - } - - return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey) -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/cache.go b/vendor/github.com/quic-go/qtls-go1-20/cache.go deleted file mode 100644 index 99e0c5fb8..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/cache.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2022 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto/x509" - "runtime" - "sync" - "sync/atomic" -) - -type cacheEntry struct { - refs atomic.Int64 - cert *x509.Certificate -} - -// certCache implements an intern table for reference counted x509.Certificates, -// implemented in a similar fashion to BoringSSL's CRYPTO_BUFFER_POOL. This -// allows for a single x509.Certificate to be kept in memory and referenced from -// multiple Conns. Returned references should not be mutated by callers. Certificates -// are still safe to use after they are removed from the cache. -// -// Certificates are returned wrapped in a activeCert struct that should be held by -// the caller. When references to the activeCert are freed, the number of references -// to the certificate in the cache is decremented. Once the number of references -// reaches zero, the entry is evicted from the cache. -// -// The main difference between this implementation and CRYPTO_BUFFER_POOL is that -// CRYPTO_BUFFER_POOL is a more generic structure which supports blobs of data, -// rather than specific structures. Since we only care about x509.Certificates, -// certCache is implemented as a specific cache, rather than a generic one. -// -// See https://boringssl.googlesource.com/boringssl/+/master/include/openssl/pool.h -// and https://boringssl.googlesource.com/boringssl/+/master/crypto/pool/pool.c -// for the BoringSSL reference. -type certCache struct { - sync.Map -} - -var clientCertCache = new(certCache) - -// activeCert is a handle to a certificate held in the cache. Once there are -// no alive activeCerts for a given certificate, the certificate is removed -// from the cache by a finalizer. -type activeCert struct { - cert *x509.Certificate -} - -// active increments the number of references to the entry, wraps the -// certificate in the entry in a activeCert, and sets the finalizer. -// -// Note that there is a race between active and the finalizer set on the -// returned activeCert, triggered if active is called after the ref count is -// decremented such that refs may be > 0 when evict is called. We consider this -// safe, since the caller holding an activeCert for an entry that is no longer -// in the cache is fine, with the only side effect being the memory overhead of -// there being more than one distinct reference to a certificate alive at once. -func (cc *certCache) active(e *cacheEntry) *activeCert { - e.refs.Add(1) - a := &activeCert{e.cert} - runtime.SetFinalizer(a, func(_ *activeCert) { - if e.refs.Add(-1) == 0 { - cc.evict(e) - } - }) - return a -} - -// evict removes a cacheEntry from the cache. -func (cc *certCache) evict(e *cacheEntry) { - cc.Delete(string(e.cert.Raw)) -} - -// newCert returns a x509.Certificate parsed from der. If there is already a copy -// of the certificate in the cache, a reference to the existing certificate will -// be returned. Otherwise, a fresh certificate will be added to the cache, and -// the reference returned. The returned reference should not be mutated. -func (cc *certCache) newCert(der []byte) (*activeCert, error) { - if entry, ok := cc.Load(string(der)); ok { - return cc.active(entry.(*cacheEntry)), nil - } - - cert, err := x509.ParseCertificate(der) - if err != nil { - return nil, err - } - - entry := &cacheEntry{cert: cert} - if entry, loaded := cc.LoadOrStore(string(der), entry); loaded { - return cc.active(entry.(*cacheEntry)), nil - } - return cc.active(entry), nil -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/cipher_suites.go b/vendor/github.com/quic-go/qtls-go1-20/cipher_suites.go deleted file mode 100644 index 2946ffb32..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/cipher_suites.go +++ /dev/null @@ -1,691 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto" - "crypto/aes" - "crypto/cipher" - "crypto/des" - "crypto/hmac" - "crypto/rc4" - "crypto/sha1" - "crypto/sha256" - "fmt" - "hash" - "runtime" - - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/sys/cpu" -) - -// CipherSuite is a TLS cipher suite. Note that most functions in this package -// accept and expose cipher suite IDs instead of this type. -type CipherSuite struct { - ID uint16 - Name string - - // Supported versions is the list of TLS protocol versions that can - // negotiate this cipher suite. - SupportedVersions []uint16 - - // Insecure is true if the cipher suite has known security issues - // due to its primitives, design, or implementation. - Insecure bool -} - -var ( - supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12} - supportedOnlyTLS12 = []uint16{VersionTLS12} - supportedOnlyTLS13 = []uint16{VersionTLS13} -) - -// CipherSuites returns a list of cipher suites currently implemented by this -// package, excluding those with security issues, which are returned by -// InsecureCipherSuites. -// -// The list is sorted by ID. Note that the default cipher suites selected by -// this package might depend on logic that can't be captured by a static list, -// and might not match those returned by this function. -func CipherSuites() []*CipherSuite { - return []*CipherSuite{ - {TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - - {TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false}, - {TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false}, - {TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false}, - - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, - } -} - -// InsecureCipherSuites returns a list of cipher suites currently implemented by -// this package and which have security issues. -// -// Most applications should not use the cipher suites in this list, and should -// only use those returned by CipherSuites. -func InsecureCipherSuites() []*CipherSuite { - // This list includes RC4, CBC_SHA256, and 3DES cipher suites. See - // cipherSuitesPreferenceOrder for details. - return []*CipherSuite{ - {TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, - {TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - } -} - -// CipherSuiteName returns the standard name for the passed cipher suite ID -// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation -// of the ID value if the cipher suite is not implemented by this package. -func CipherSuiteName(id uint16) string { - for _, c := range CipherSuites() { - if c.ID == id { - return c.Name - } - } - for _, c := range InsecureCipherSuites() { - if c.ID == id { - return c.Name - } - } - return fmt.Sprintf("0x%04X", id) -} - -const ( - // suiteECDHE indicates that the cipher suite involves elliptic curve - // Diffie-Hellman. This means that it should only be selected when the - // client indicates that it supports ECC with a curve and point format - // that we're happy with. - suiteECDHE = 1 << iota - // suiteECSign indicates that the cipher suite involves an ECDSA or - // EdDSA signature and therefore may only be selected when the server's - // certificate is ECDSA or EdDSA. If this is not set then the cipher suite - // is RSA based. - suiteECSign - // suiteTLS12 indicates that the cipher suite should only be advertised - // and accepted when using TLS 1.2. - suiteTLS12 - // suiteSHA384 indicates that the cipher suite uses SHA384 as the - // handshake hash. - suiteSHA384 -) - -// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange -// mechanism, as well as the cipher+MAC pair or the AEAD. -type cipherSuite struct { - id uint16 - // the lengths, in bytes, of the key material needed for each component. - keyLen int - macLen int - ivLen int - ka func(version uint16) keyAgreement - // flags is a bitmask of the suite* values, above. - flags int - cipher func(key, iv []byte, isRead bool) any - mac func(key []byte) hash.Hash - aead func(key, fixedNonce []byte) aead -} - -var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter. - {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, - {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, - {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, - {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil}, - {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, - {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, - {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, - {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil}, -} - -// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which -// is also in supportedIDs and passes the ok filter. -func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite { - for _, id := range ids { - candidate := cipherSuiteByID(id) - if candidate == nil || !ok(candidate) { - continue - } - - for _, suppID := range supportedIDs { - if id == suppID { - return candidate - } - } - } - return nil -} - -// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash -// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. -type cipherSuiteTLS13 struct { - id uint16 - keyLen int - aead func(key, fixedNonce []byte) aead - hash crypto.Hash -} - -var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map. - {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, - {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, - {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, -} - -// cipherSuitesPreferenceOrder is the order in which we'll select (on the -// server) or advertise (on the client) TLS 1.0–1.2 cipher suites. -// -// Cipher suites are filtered but not reordered based on the application and -// peer's preferences, meaning we'll never select a suite lower in this list if -// any higher one is available. This makes it more defensible to keep weaker -// cipher suites enabled, especially on the server side where we get the last -// word, since there are no known downgrade attacks on cipher suites selection. -// -// The list is sorted by applying the following priority rules, stopping at the -// first (most important) applicable one: -// -// - Anything else comes before RC4 -// -// RC4 has practically exploitable biases. See https://www.rc4nomore.com. -// -// - Anything else comes before CBC_SHA256 -// -// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13 -// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and -// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. -// -// - Anything else comes before 3DES -// -// 3DES has 64-bit blocks, which makes it fundamentally susceptible to -// birthday attacks. See https://sweet32.info. -// -// - ECDHE comes before anything else -// -// Once we got the broken stuff out of the way, the most important -// property a cipher suite can have is forward secrecy. We don't -// implement FFDHE, so that means ECDHE. -// -// - AEADs come before CBC ciphers -// -// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites -// are fundamentally fragile, and suffered from an endless sequence of -// padding oracle attacks. See https://eprint.iacr.org/2015/1129, -// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and -// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/. -// -// - AES comes before ChaCha20 -// -// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster -// than ChaCha20Poly1305. -// -// When AES hardware is not available, AES-128-GCM is one or more of: much -// slower, way more complex, and less safe (because not constant time) -// than ChaCha20Poly1305. -// -// We use this list if we think both peers have AES hardware, and -// cipherSuitesPreferenceOrderNoAES otherwise. -// -// - AES-128 comes before AES-256 -// -// The only potential advantages of AES-256 are better multi-target -// margins, and hypothetical post-quantum properties. Neither apply to -// TLS, and AES-256 is slower due to its four extra rounds (which don't -// contribute to the advantages above). -// -// - ECDSA comes before RSA -// -// The relative order of ECDSA and RSA cipher suites doesn't matter, -// as they depend on the certificate. Pick one to get a stable order. -var cipherSuitesPreferenceOrder = []uint16{ - // AEADs w/ ECDHE - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - - // CBC w/ ECDHE - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - - // AEADs w/o ECDHE - TLS_RSA_WITH_AES_128_GCM_SHA256, - TLS_RSA_WITH_AES_256_GCM_SHA384, - - // CBC w/o ECDHE - TLS_RSA_WITH_AES_128_CBC_SHA, - TLS_RSA_WITH_AES_256_CBC_SHA, - - // 3DES - TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - TLS_RSA_WITH_3DES_EDE_CBC_SHA, - - // CBC_SHA256 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - TLS_RSA_WITH_AES_128_CBC_SHA256, - - // RC4 - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, - TLS_RSA_WITH_RC4_128_SHA, -} - -var cipherSuitesPreferenceOrderNoAES = []uint16{ - // ChaCha20Poly1305 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - - // AES-GCM w/ ECDHE - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - - // The rest of cipherSuitesPreferenceOrder. - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - TLS_RSA_WITH_AES_128_GCM_SHA256, - TLS_RSA_WITH_AES_256_GCM_SHA384, - TLS_RSA_WITH_AES_128_CBC_SHA, - TLS_RSA_WITH_AES_256_CBC_SHA, - TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - TLS_RSA_WITH_3DES_EDE_CBC_SHA, - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - TLS_RSA_WITH_AES_128_CBC_SHA256, - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, - TLS_RSA_WITH_RC4_128_SHA, -} - -// disabledCipherSuites are not used unless explicitly listed in -// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder. -var disabledCipherSuites = []uint16{ - // CBC_SHA256 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - TLS_RSA_WITH_AES_128_CBC_SHA256, - - // RC4 - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, - TLS_RSA_WITH_RC4_128_SHA, -} - -var ( - defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites) - defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen] -) - -// defaultCipherSuitesTLS13 is also the preference order, since there are no -// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as -// cipherSuitesPreferenceOrder applies. -var defaultCipherSuitesTLS13 = []uint16{ - TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, - TLS_CHACHA20_POLY1305_SHA256, -} - -var defaultCipherSuitesTLS13NoAES = []uint16{ - TLS_CHACHA20_POLY1305_SHA256, - TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, -} - -var ( - hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ - hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL - // Keep in sync with crypto/aes/cipher_s390x.go. - hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && - (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) - - hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 || - runtime.GOARCH == "arm64" && hasGCMAsmARM64 || - runtime.GOARCH == "s390x" && hasGCMAsmS390X -) - -var aesgcmCiphers = map[uint16]bool{ - // TLS 1.2 - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true, - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true, - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true, - // TLS 1.3 - TLS_AES_128_GCM_SHA256: true, - TLS_AES_256_GCM_SHA384: true, -} - -var nonAESGCMAEADCiphers = map[uint16]bool{ - // TLS 1.2 - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true, - // TLS 1.3 - TLS_CHACHA20_POLY1305_SHA256: true, -} - -// aesgcmPreferred returns whether the first known cipher in the preference list -// is an AES-GCM cipher, implying the peer has hardware support for it. -func aesgcmPreferred(ciphers []uint16) bool { - for _, cID := range ciphers { - if c := cipherSuiteByID(cID); c != nil { - return aesgcmCiphers[cID] - } - if c := cipherSuiteTLS13ByID(cID); c != nil { - return aesgcmCiphers[cID] - } - } - return false -} - -func cipherRC4(key, iv []byte, isRead bool) any { - cipher, _ := rc4.NewCipher(key) - return cipher -} - -func cipher3DES(key, iv []byte, isRead bool) any { - block, _ := des.NewTripleDESCipher(key) - if isRead { - return cipher.NewCBCDecrypter(block, iv) - } - return cipher.NewCBCEncrypter(block, iv) -} - -func cipherAES(key, iv []byte, isRead bool) any { - block, _ := aes.NewCipher(key) - if isRead { - return cipher.NewCBCDecrypter(block, iv) - } - return cipher.NewCBCEncrypter(block, iv) -} - -// macSHA1 returns a SHA-1 based constant time MAC. -func macSHA1(key []byte) hash.Hash { - h := sha1.New - h = newConstantTimeHash(h) - return hmac.New(h, key) -} - -// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and -// is currently only used in disabled-by-default cipher suites. -func macSHA256(key []byte) hash.Hash { - return hmac.New(sha256.New, key) -} - -type aead interface { - cipher.AEAD - - // explicitNonceLen returns the number of bytes of explicit nonce - // included in each record. This is eight for older AEADs and - // zero for modern ones. - explicitNonceLen() int -} - -const ( - aeadNonceLength = 12 - noncePrefixLength = 4 -) - -// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to -// each call. -type prefixNonceAEAD struct { - // nonce contains the fixed part of the nonce in the first four bytes. - nonce [aeadNonceLength]byte - aead cipher.AEAD -} - -func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength } -func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() } -func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() } - -func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { - copy(f.nonce[4:], nonce) - return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) -} - -func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { - copy(f.nonce[4:], nonce) - return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) -} - -// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce -// before each call. -type xorNonceAEAD struct { - nonceMask [aeadNonceLength]byte - aead cipher.AEAD -} - -func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number -func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } -func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } - -func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - - return result -} - -func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - - return result, err -} - -func aeadAESGCM(key, noncePrefix []byte) aead { - if len(noncePrefix) != noncePrefixLength { - panic("tls: internal error: wrong nonce length") - } - aes, err := aes.NewCipher(key) - if err != nil { - panic(err) - } - var aead cipher.AEAD - aead, err = cipher.NewGCM(aes) - if err != nil { - panic(err) - } - - ret := &prefixNonceAEAD{aead: aead} - copy(ret.nonce[:], noncePrefix) - return ret -} - -func aeadAESGCMTLS13(key, nonceMask []byte) aead { - if len(nonceMask) != aeadNonceLength { - panic("tls: internal error: wrong nonce length") - } - aes, err := aes.NewCipher(key) - if err != nil { - panic(err) - } - aead, err := cipher.NewGCM(aes) - if err != nil { - panic(err) - } - - ret := &xorNonceAEAD{aead: aead} - copy(ret.nonceMask[:], nonceMask) - return ret -} - -func aeadChaCha20Poly1305(key, nonceMask []byte) aead { - if len(nonceMask) != aeadNonceLength { - panic("tls: internal error: wrong nonce length") - } - aead, err := chacha20poly1305.New(key) - if err != nil { - panic(err) - } - - ret := &xorNonceAEAD{aead: aead} - copy(ret.nonceMask[:], nonceMask) - return ret -} - -type constantTimeHash interface { - hash.Hash - ConstantTimeSum(b []byte) []byte -} - -// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces -// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. -type cthWrapper struct { - h constantTimeHash -} - -func (c *cthWrapper) Size() int { return c.h.Size() } -func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } -func (c *cthWrapper) Reset() { c.h.Reset() } -func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } -func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } - -func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { - return func() hash.Hash { - return &cthWrapper{h().(constantTimeHash)} - } -} - -// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. -func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte { - h.Reset() - h.Write(seq) - h.Write(header) - h.Write(data) - res := h.Sum(out) - if extra != nil { - h.Write(extra) - } - return res -} - -func rsaKA(version uint16) keyAgreement { - return rsaKeyAgreement{} -} - -func ecdheECDSAKA(version uint16) keyAgreement { - return &ecdheKeyAgreement{ - isRSA: false, - version: version, - } -} - -func ecdheRSAKA(version uint16) keyAgreement { - return &ecdheKeyAgreement{ - isRSA: true, - version: version, - } -} - -// mutualCipherSuite returns a cipherSuite given a list of supported -// ciphersuites and the id requested by the peer. -func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { - for _, id := range have { - if id == want { - return cipherSuiteByID(id) - } - } - return nil -} - -func cipherSuiteByID(id uint16) *cipherSuite { - for _, cipherSuite := range cipherSuites { - if cipherSuite.id == id { - return cipherSuite - } - } - return nil -} - -func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 { - for _, id := range have { - if id == want { - return cipherSuiteTLS13ByID(id) - } - } - return nil -} - -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 { - for _, cipherSuite := range cipherSuitesTLS13 { - if cipherSuite.id == id { - return cipherSuite - } - } - return nil -} - -// A list of cipher suite IDs that are, or have been, implemented by this -// package. -// -// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml -const ( - // TLS 1.0 - 1.2 cipher suites. - TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 - TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a - TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f - TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 - TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c - TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c - TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a - TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 - TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 - TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 - TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 - TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9 - - // TLS 1.3 cipher suites. - TLS_AES_128_GCM_SHA256 uint16 = 0x1301 - TLS_AES_256_GCM_SHA384 uint16 = 0x1302 - TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 - - // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator - // that the client is doing version fallback. See RFC 7507. - TLS_FALLBACK_SCSV uint16 = 0x5600 - - // Legacy names for the corresponding cipher suites with the correct _SHA256 - // suffix, retained for backward compatibility. - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 -) diff --git a/vendor/github.com/quic-go/qtls-go1-20/common.go b/vendor/github.com/quic-go/qtls-go1-20/common.go deleted file mode 100644 index 841c1a44b..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/common.go +++ /dev/null @@ -1,1454 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "container/list" - "context" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/sha512" - "crypto/tls" - "crypto/x509" - "errors" - "fmt" - "io" - "net" - "strings" - "sync" - "time" -) - -const ( - VersionTLS10 = 0x0301 - VersionTLS11 = 0x0302 - VersionTLS12 = 0x0303 - VersionTLS13 = 0x0304 - - // Deprecated: SSLv3 is cryptographically broken, and is no longer - // supported by this package. See golang.org/issue/32716. - VersionSSL30 = 0x0300 -) - -const ( - maxPlaintext = 16384 // maximum plaintext payload length - maxCiphertext = 16384 + 2048 // maximum ciphertext payload length - maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 - recordHeaderLen = 5 // record header length - maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) - maxUselessRecords = 16 // maximum number of consecutive non-advancing records -) - -// TLS record types. -type recordType uint8 - -const ( - recordTypeChangeCipherSpec recordType = 20 - recordTypeAlert recordType = 21 - recordTypeHandshake recordType = 22 - recordTypeApplicationData recordType = 23 -) - -// TLS handshake message types. -const ( - typeHelloRequest uint8 = 0 - typeClientHello uint8 = 1 - typeServerHello uint8 = 2 - typeNewSessionTicket uint8 = 4 - typeEndOfEarlyData uint8 = 5 - typeEncryptedExtensions uint8 = 8 - typeCertificate uint8 = 11 - typeServerKeyExchange uint8 = 12 - typeCertificateRequest uint8 = 13 - typeServerHelloDone uint8 = 14 - typeCertificateVerify uint8 = 15 - typeClientKeyExchange uint8 = 16 - typeFinished uint8 = 20 - typeCertificateStatus uint8 = 22 - typeKeyUpdate uint8 = 24 - typeNextProtocol uint8 = 67 // Not IANA assigned - typeMessageHash uint8 = 254 // synthetic message -) - -// TLS compression types. -const ( - compressionNone uint8 = 0 -) - -// TLS extension numbers -const ( - extensionServerName uint16 = 0 - extensionStatusRequest uint16 = 5 - extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 - extensionSupportedPoints uint16 = 11 - extensionSignatureAlgorithms uint16 = 13 - extensionALPN uint16 = 16 - extensionSCT uint16 = 18 - extensionSessionTicket uint16 = 35 - extensionPreSharedKey uint16 = 41 - extensionEarlyData uint16 = 42 - extensionSupportedVersions uint16 = 43 - extensionCookie uint16 = 44 - extensionPSKModes uint16 = 45 - extensionCertificateAuthorities uint16 = 47 - extensionSignatureAlgorithmsCert uint16 = 50 - extensionKeyShare uint16 = 51 - extensionQUICTransportParameters uint16 = 57 - extensionRenegotiationInfo uint16 = 0xff01 -) - -// TLS signaling cipher suite values -const ( - scsvRenegotiation uint16 = 0x00ff -) - -// CurveID is a tls.CurveID -type CurveID = tls.CurveID - -const ( - CurveP256 CurveID = 23 - CurveP384 CurveID = 24 - CurveP521 CurveID = 25 - X25519 CurveID = 29 -) - -// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. -type keyShare struct { - group CurveID - data []byte -} - -// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9. -const ( - pskModePlain uint8 = 0 - pskModeDHE uint8 = 1 -) - -// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved -// session. See RFC 8446, Section 4.2.11. -type pskIdentity struct { - label []byte - obfuscatedTicketAge uint32 -} - -// TLS Elliptic Curve Point Formats -// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 -const ( - pointFormatUncompressed uint8 = 0 -) - -// TLS CertificateStatusType (RFC 3546) -const ( - statusTypeOCSP uint8 = 1 -) - -// Certificate types (for certificateRequestMsg) -const ( - certTypeRSASign = 1 - certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3. -) - -// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with -// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do. -const ( - signaturePKCS1v15 uint8 = iota + 225 - signatureRSAPSS - signatureECDSA - signatureEd25519 -) - -// directSigning is a standard Hash value that signals that no pre-hashing -// should be performed, and that the input should be signed directly. It is the -// hash function associated with the Ed25519 signature scheme. -var directSigning crypto.Hash = 0 - -// defaultSupportedSignatureAlgorithms contains the signature and hash algorithms that -// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ -// CertificateRequest. The two fields are merged to match with TLS 1.3. -// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. -var defaultSupportedSignatureAlgorithms = []SignatureScheme{ - PSSWithSHA256, - ECDSAWithP256AndSHA256, - Ed25519, - PSSWithSHA384, - PSSWithSHA512, - PKCS1WithSHA256, - PKCS1WithSHA384, - PKCS1WithSHA512, - ECDSAWithP384AndSHA384, - ECDSAWithP521AndSHA512, - PKCS1WithSHA1, - ECDSAWithSHA1, -} - -// helloRetryRequestRandom is set as the Random value of a ServerHello -// to signal that the message is actually a HelloRetryRequest. -var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. - 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, - 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, - 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, -} - -const ( - // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server - // random as a downgrade protection if the server would be capable of - // negotiating a higher version. See RFC 8446, Section 4.1.3. - downgradeCanaryTLS12 = "DOWNGRD\x01" - downgradeCanaryTLS11 = "DOWNGRD\x00" -) - -// testingOnlyForceDowngradeCanary is set in tests to force the server side to -// include downgrade canaries even if it's using its highers supported version. -var testingOnlyForceDowngradeCanary bool - -type ConnectionState = tls.ConnectionState - -// ConnectionState records basic TLS details about the connection. -type connectionState struct { - // Version is the TLS version used by the connection (e.g. VersionTLS12). - Version uint16 - - // HandshakeComplete is true if the handshake has concluded. - HandshakeComplete bool - - // DidResume is true if this connection was successfully resumed from a - // previous session with a session ticket or similar mechanism. - DidResume bool - - // CipherSuite is the cipher suite negotiated for the connection (e.g. - // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256). - CipherSuite uint16 - - // NegotiatedProtocol is the application protocol negotiated with ALPN. - NegotiatedProtocol string - - // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. - // - // Deprecated: this value is always true. - NegotiatedProtocolIsMutual bool - - // ServerName is the value of the Server Name Indication extension sent by - // the client. It's available both on the server and on the client side. - ServerName string - - // PeerCertificates are the parsed certificates sent by the peer, in the - // order in which they were sent. The first element is the leaf certificate - // that the connection is verified against. - // - // On the client side, it can't be empty. On the server side, it can be - // empty if Config.ClientAuth is not RequireAnyClientCert or - // RequireAndVerifyClientCert. - // - // PeerCertificates and its contents should not be modified. - PeerCertificates []*x509.Certificate - - // VerifiedChains is a list of one or more chains where the first element is - // PeerCertificates[0] and the last element is from Config.RootCAs (on the - // client side) or Config.ClientCAs (on the server side). - // - // On the client side, it's set if Config.InsecureSkipVerify is false. On - // the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven - // (and the peer provided a certificate) or RequireAndVerifyClientCert. - // - // VerifiedChains and its contents should not be modified. - VerifiedChains [][]*x509.Certificate - - // SignedCertificateTimestamps is a list of SCTs provided by the peer - // through the TLS handshake for the leaf certificate, if any. - SignedCertificateTimestamps [][]byte - - // OCSPResponse is a stapled Online Certificate Status Protocol (OCSP) - // response provided by the peer for the leaf certificate, if any. - OCSPResponse []byte - - // TLSUnique contains the "tls-unique" channel binding value (see RFC 5929, - // Section 3). This value will be nil for TLS 1.3 connections and for all - // resumed connections. - // - // Deprecated: there are conditions in which this value might not be unique - // to a connection. See the Security Considerations sections of RFC 5705 and - // RFC 7627, and https://mitls.org/pages/attacks/3SHAKE#channelbindings. - TLSUnique []byte - - // ekm is a closure exposed via ExportKeyingMaterial. - ekm func(label string, context []byte, length int) ([]byte, error) -} - -// ClientAuthType is tls.ClientAuthType -type ClientAuthType = tls.ClientAuthType - -const ( - NoClientCert = tls.NoClientCert - RequestClientCert = tls.RequestClientCert - RequireAnyClientCert = tls.RequireAnyClientCert - VerifyClientCertIfGiven = tls.VerifyClientCertIfGiven - RequireAndVerifyClientCert = tls.RequireAndVerifyClientCert -) - -// requiresClientCert reports whether the ClientAuthType requires a client -// certificate to be provided. -func requiresClientCert(c ClientAuthType) bool { - switch c { - case RequireAnyClientCert, RequireAndVerifyClientCert: - return true - default: - return false - } -} - -// ClientSessionState contains the state needed by clients to resume TLS -// sessions. -type ClientSessionState = tls.ClientSessionState - -type clientSessionState struct { - sessionTicket []uint8 // Encrypted ticket used for session resumption with server - vers uint16 // TLS version negotiated for the session - cipherSuite uint16 // Ciphersuite negotiated for the session - masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret - serverCertificates []*x509.Certificate // Certificate chain presented by the server - verifiedChains [][]*x509.Certificate // Certificate chains we built for verification - receivedAt time.Time // When the session ticket was received from the server - ocspResponse []byte // Stapled OCSP response presented by the server - scts [][]byte // SCTs presented by the server - - // TLS 1.3 fields. - nonce []byte // Ticket nonce sent by the server, to derive PSK - useBy time.Time // Expiration of the ticket lifetime as set by the server - ageAdd uint32 // Random obfuscation factor for sending the ticket age -} - -// ClientSessionCache is a cache of ClientSessionState objects that can be used -// by a client to resume a TLS session with a given server. ClientSessionCache -// implementations should expect to be called concurrently from different -// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not -// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which -// are supported via this interface. -type ClientSessionCache = tls.ClientSessionCache - -// SignatureScheme is a tls.SignatureScheme -type SignatureScheme = tls.SignatureScheme - -const ( - // RSASSA-PKCS1-v1_5 algorithms. - PKCS1WithSHA256 SignatureScheme = 0x0401 - PKCS1WithSHA384 SignatureScheme = 0x0501 - PKCS1WithSHA512 SignatureScheme = 0x0601 - - // RSASSA-PSS algorithms with public key OID rsaEncryption. - PSSWithSHA256 SignatureScheme = 0x0804 - PSSWithSHA384 SignatureScheme = 0x0805 - PSSWithSHA512 SignatureScheme = 0x0806 - - // ECDSA algorithms. Only constrained to a specific curve in TLS 1.3. - ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 - ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 - ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 - - // EdDSA algorithms. - Ed25519 SignatureScheme = 0x0807 - - // Legacy signature and hash algorithms for TLS 1.2. - PKCS1WithSHA1 SignatureScheme = 0x0201 - ECDSAWithSHA1 SignatureScheme = 0x0203 -) - -// ClientHelloInfo contains information from a ClientHello message in order to -// guide application logic in the GetCertificate and GetConfigForClient callbacks. -type ClientHelloInfo = tls.ClientHelloInfo - -type clientHelloInfo struct { - // CipherSuites lists the CipherSuites supported by the client (e.g. - // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). - CipherSuites []uint16 - - // ServerName indicates the name of the server requested by the client - // in order to support virtual hosting. ServerName is only set if the - // client is using SNI (see RFC 4366, Section 3.1). - ServerName string - - // SupportedCurves lists the elliptic curves supported by the client. - // SupportedCurves is set only if the Supported Elliptic Curves - // Extension is being used (see RFC 4492, Section 5.1.1). - SupportedCurves []CurveID - - // SupportedPoints lists the point formats supported by the client. - // SupportedPoints is set only if the Supported Point Formats Extension - // is being used (see RFC 4492, Section 5.1.2). - SupportedPoints []uint8 - - // SignatureSchemes lists the signature and hash schemes that the client - // is willing to verify. SignatureSchemes is set only if the Signature - // Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1). - SignatureSchemes []SignatureScheme - - // SupportedProtos lists the application protocols supported by the client. - // SupportedProtos is set only if the Application-Layer Protocol - // Negotiation Extension is being used (see RFC 7301, Section 3.1). - // - // Servers can select a protocol by setting Config.NextProtos in a - // GetConfigForClient return value. - SupportedProtos []string - - // SupportedVersions lists the TLS versions supported by the client. - // For TLS versions less than 1.3, this is extrapolated from the max - // version advertised by the client, so values other than the greatest - // might be rejected if used. - SupportedVersions []uint16 - - // Conn is the underlying net.Conn for the connection. Do not read - // from, or write to, this connection; that will cause the TLS - // connection to fail. - Conn net.Conn - - // config is embedded by the GetCertificate or GetConfigForClient caller, - // for use with SupportsCertificate. - config *Config - - // ctx is the context of the handshake that is in progress. - ctx context.Context -} - -// Context returns the context of the handshake that is in progress. -// This context is a child of the context passed to HandshakeContext, -// if any, and is canceled when the handshake concludes. -func (c *clientHelloInfo) Context() context.Context { - return c.ctx -} - -// CertificateRequestInfo contains information from a server's -// CertificateRequest message, which is used to demand a certificate and proof -// of control from a client. -type CertificateRequestInfo = tls.CertificateRequestInfo - -type certificateRequestInfo struct { - // AcceptableCAs contains zero or more, DER-encoded, X.501 - // Distinguished Names. These are the names of root or intermediate CAs - // that the server wishes the returned certificate to be signed by. An - // empty slice indicates that the server has no preference. - AcceptableCAs [][]byte - - // SignatureSchemes lists the signature schemes that the server is - // willing to verify. - SignatureSchemes []SignatureScheme - - // Version is the TLS version that was negotiated for this connection. - Version uint16 - - // ctx is the context of the handshake that is in progress. - ctx context.Context -} - -// Context returns the context of the handshake that is in progress. -// This context is a child of the context passed to HandshakeContext, -// if any, and is canceled when the handshake concludes. -func (c *certificateRequestInfo) Context() context.Context { - return c.ctx -} - -// RenegotiationSupport enumerates the different levels of support for TLS -// renegotiation. TLS renegotiation is the act of performing subsequent -// handshakes on a connection after the first. This significantly complicates -// the state machine and has been the source of numerous, subtle security -// issues. Initiating a renegotiation is not supported, but support for -// accepting renegotiation requests may be enabled. -// -// Even when enabled, the server may not change its identity between handshakes -// (i.e. the leaf certificate must be the same). Additionally, concurrent -// handshake and application data flow is not permitted so renegotiation can -// only be used with protocols that synchronise with the renegotiation, such as -// HTTPS. -// -// Renegotiation is not defined in TLS 1.3. -type RenegotiationSupport = tls.RenegotiationSupport - -const ( - // RenegotiateNever disables renegotiation. - RenegotiateNever = tls.RenegotiateNever - - // RenegotiateOnceAsClient allows a remote server to request - // renegotiation once per connection. - RenegotiateOnceAsClient = tls.RenegotiateOnceAsClient - - // RenegotiateFreelyAsClient allows a remote server to repeatedly - // request renegotiation. - RenegotiateFreelyAsClient = tls.RenegotiateFreelyAsClient -) - -// A Config structure is used to configure a TLS client or server. -// After one has been passed to a TLS function it must not be -// modified. A Config may be reused; the tls package will also not -// modify it. -type Config = tls.Config - -type config struct { - // Rand provides the source of entropy for nonces and RSA blinding. - // If Rand is nil, TLS uses the cryptographic random reader in package - // crypto/rand. - // The Reader must be safe for use by multiple goroutines. - Rand io.Reader - - // Time returns the current time as the number of seconds since the epoch. - // If Time is nil, TLS uses time.Now. - Time func() time.Time - - // Certificates contains one or more certificate chains to present to the - // other side of the connection. The first certificate compatible with the - // peer's requirements is selected automatically. - // - // Server configurations must set one of Certificates, GetCertificate or - // GetConfigForClient. Clients doing client-authentication may set either - // Certificates or GetClientCertificate. - // - // Note: if there are multiple Certificates, and they don't have the - // optional field Leaf set, certificate selection will incur a significant - // per-handshake performance cost. - Certificates []Certificate - - // NameToCertificate maps from a certificate name to an element of - // Certificates. Note that a certificate name can be of the form - // '*.example.com' and so doesn't have to be a domain name as such. - // - // Deprecated: NameToCertificate only allows associating a single - // certificate with a given name. Leave this field nil to let the library - // select the first compatible chain from Certificates. - NameToCertificate map[string]*Certificate - - // GetCertificate returns a Certificate based on the given - // ClientHelloInfo. It will only be called if the client supplies SNI - // information or if Certificates is empty. - // - // If GetCertificate is nil or returns nil, then the certificate is - // retrieved from NameToCertificate. If NameToCertificate is nil, the - // best element of Certificates will be used. - // - // Once a Certificate is returned it should not be modified. - GetCertificate func(*ClientHelloInfo) (*Certificate, error) - - // GetClientCertificate, if not nil, is called when a server requests a - // certificate from a client. If set, the contents of Certificates will - // be ignored. - // - // If GetClientCertificate returns an error, the handshake will be - // aborted and that error will be returned. Otherwise - // GetClientCertificate must return a non-nil Certificate. If - // Certificate.Certificate is empty then no certificate will be sent to - // the server. If this is unacceptable to the server then it may abort - // the handshake. - // - // GetClientCertificate may be called multiple times for the same - // connection if renegotiation occurs or if TLS 1.3 is in use. - // - // Once a Certificate is returned it should not be modified. - GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) - - // GetConfigForClient, if not nil, is called after a ClientHello is - // received from a client. It may return a non-nil Config in order to - // change the Config that will be used to handle this connection. If - // the returned Config is nil, the original Config will be used. The - // Config returned by this callback may not be subsequently modified. - // - // If GetConfigForClient is nil, the Config passed to Server() will be - // used for all connections. - // - // If SessionTicketKey was explicitly set on the returned Config, or if - // SetSessionTicketKeys was called on the returned Config, those keys will - // be used. Otherwise, the original Config keys will be used (and possibly - // rotated if they are automatically managed). - GetConfigForClient func(*ClientHelloInfo) (*Config, error) - - // VerifyPeerCertificate, if not nil, is called after normal - // certificate verification by either a TLS client or server. It - // receives the raw ASN.1 certificates provided by the peer and also - // any verified chains that normal processing found. If it returns a - // non-nil error, the handshake is aborted and that error results. - // - // If normal verification fails then the handshake will abort before - // considering this callback. If normal verification is disabled by - // setting InsecureSkipVerify, or (for a server) when ClientAuth is - // RequestClientCert or RequireAnyClientCert, then this callback will - // be considered but the verifiedChains argument will always be nil. - // - // verifiedChains and its contents should not be modified. - VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error - - // VerifyConnection, if not nil, is called after normal certificate - // verification and after VerifyPeerCertificate by either a TLS client - // or server. If it returns a non-nil error, the handshake is aborted - // and that error results. - // - // If normal verification fails then the handshake will abort before - // considering this callback. This callback will run for all connections - // regardless of InsecureSkipVerify or ClientAuth settings. - VerifyConnection func(ConnectionState) error - - // RootCAs defines the set of root certificate authorities - // that clients use when verifying server certificates. - // If RootCAs is nil, TLS uses the host's root CA set. - RootCAs *x509.CertPool - - // NextProtos is a list of supported application level protocols, in - // order of preference. If both peers support ALPN, the selected - // protocol will be one from this list, and the connection will fail - // if there is no mutually supported protocol. If NextProtos is empty - // or the peer doesn't support ALPN, the connection will succeed and - // ConnectionState.NegotiatedProtocol will be empty. - NextProtos []string - - // ServerName is used to verify the hostname on the returned - // certificates unless InsecureSkipVerify is given. It is also included - // in the client's handshake to support virtual hosting unless it is - // an IP address. - ServerName string - - // ClientAuth determines the server's policy for - // TLS Client Authentication. The default is NoClientCert. - ClientAuth ClientAuthType - - // ClientCAs defines the set of root certificate authorities - // that servers use if required to verify a client certificate - // by the policy in ClientAuth. - ClientCAs *x509.CertPool - - // InsecureSkipVerify controls whether a client verifies the server's - // certificate chain and host name. If InsecureSkipVerify is true, crypto/tls - // accepts any certificate presented by the server and any host name in that - // certificate. In this mode, TLS is susceptible to machine-in-the-middle - // attacks unless custom verification is used. This should be used only for - // testing or in combination with VerifyConnection or VerifyPeerCertificate. - InsecureSkipVerify bool - - // CipherSuites is a list of enabled TLS 1.0–1.2 cipher suites. The order of - // the list is ignored. Note that TLS 1.3 ciphersuites are not configurable. - // - // If CipherSuites is nil, a safe default list is used. The default cipher - // suites might change over time. - CipherSuites []uint16 - - // PreferServerCipherSuites is a legacy field and has no effect. - // - // It used to control whether the server would follow the client's or the - // server's preference. Servers now select the best mutually supported - // cipher suite based on logic that takes into account inferred client - // hardware, server hardware, and security. - // - // Deprecated: PreferServerCipherSuites is ignored. - PreferServerCipherSuites bool - - // SessionTicketsDisabled may be set to true to disable session ticket and - // PSK (resumption) support. Note that on clients, session ticket support is - // also disabled if ClientSessionCache is nil. - SessionTicketsDisabled bool - - // SessionTicketKey is used by TLS servers to provide session resumption. - // See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled - // with random data before the first server handshake. - // - // Deprecated: if this field is left at zero, session ticket keys will be - // automatically rotated every day and dropped after seven days. For - // customizing the rotation schedule or synchronizing servers that are - // terminating connections for the same host, use SetSessionTicketKeys. - SessionTicketKey [32]byte - - // ClientSessionCache is a cache of ClientSessionState entries for TLS - // session resumption. It is only used by clients. - ClientSessionCache ClientSessionCache - - // MinVersion contains the minimum TLS version that is acceptable. - // - // By default, TLS 1.2 is currently used as the minimum when acting as a - // client, and TLS 1.0 when acting as a server. TLS 1.0 is the minimum - // supported by this package, both as a client and as a server. - // - // The client-side default can temporarily be reverted to TLS 1.0 by - // including the value "x509sha1=1" in the GODEBUG environment variable. - // Note that this option will be removed in Go 1.19 (but it will still be - // possible to set this field to VersionTLS10 explicitly). - MinVersion uint16 - - // MaxVersion contains the maximum TLS version that is acceptable. - // - // By default, the maximum version supported by this package is used, - // which is currently TLS 1.3. - MaxVersion uint16 - - // CurvePreferences contains the elliptic curves that will be used in - // an ECDHE handshake, in preference order. If empty, the default will - // be used. The client will use the first preference as the type for - // its key share in TLS 1.3. This may change in the future. - CurvePreferences []CurveID - - // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. - // When true, the largest possible TLS record size is always used. When - // false, the size of TLS records may be adjusted in an attempt to - // improve latency. - DynamicRecordSizingDisabled bool - - // Renegotiation controls what types of renegotiation are supported. - // The default, none, is correct for the vast majority of applications. - Renegotiation RenegotiationSupport - - // KeyLogWriter optionally specifies a destination for TLS master secrets - // in NSS key log format that can be used to allow external programs - // such as Wireshark to decrypt TLS connections. - // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. - // Use of KeyLogWriter compromises security and should only be - // used for debugging. - KeyLogWriter io.Writer - - // mutex protects sessionTicketKeys and autoSessionTicketKeys. - mutex sync.RWMutex - // sessionTicketKeys contains zero or more ticket keys. If set, it means - // the keys were set with SessionTicketKey or SetSessionTicketKeys. The - // first key is used for new tickets and any subsequent keys can be used to - // decrypt old tickets. The slice contents are not protected by the mutex - // and are immutable. - sessionTicketKeys []ticketKey - // autoSessionTicketKeys is like sessionTicketKeys but is owned by the - // auto-rotation logic. See Config.ticketKeys. - autoSessionTicketKeys []ticketKey -} - -type ExtraConfig struct { - // If Enable0RTT is enabled, the client will be allowed to send early data when resuming a session. - // - // It has no meaning on the client. - Enable0RTT bool - - // GetAppDataForSessionTicket requests application data to be sent with a session ticket. - // - // It has no meaning on the client. - GetAppDataForSessionTicket func() []byte - - // The Accept0RTT callback is called when the client offers 0-RTT. - // The server then has to decide if it wants to accept or reject 0-RTT. - // It is only used for servers. - Accept0RTT func(appData []byte) bool - - // Is called when the client saves a session ticket to the session ticket. - // This gives the application the opportunity to save some data along with the ticket, - // which can be restored when the session ticket is used. - GetAppDataForSessionState func() []byte - - // Is called when the client uses a session ticket. - // Restores the application data that was saved earlier on GetAppDataForSessionTicket. - SetAppDataFromSessionState func([]byte) -} - -// Clone clones. -func (c *ExtraConfig) Clone() *ExtraConfig { - return &ExtraConfig{ - Enable0RTT: c.Enable0RTT, - GetAppDataForSessionTicket: c.GetAppDataForSessionTicket, - Accept0RTT: c.Accept0RTT, - GetAppDataForSessionState: c.GetAppDataForSessionState, - SetAppDataFromSessionState: c.SetAppDataFromSessionState, - } -} - -const ( - // ticketKeyNameLen is the number of bytes of identifier that is prepended to - // an encrypted session ticket in order to identify the key used to encrypt it. - ticketKeyNameLen = 16 - - // ticketKeyLifetime is how long a ticket key remains valid and can be used to - // resume a client connection. - ticketKeyLifetime = 7 * 24 * time.Hour // 7 days - - // ticketKeyRotation is how often the server should rotate the session ticket key - // that is used for new tickets. - ticketKeyRotation = 24 * time.Hour -) - -// ticketKey is the internal representation of a session ticket key. -type ticketKey struct { - // keyName is an opaque byte string that serves to identify the session - // ticket key. It's exposed as plaintext in every session ticket. - keyName [ticketKeyNameLen]byte - aesKey [16]byte - hmacKey [16]byte - // created is the time at which this ticket key was created. See Config.ticketKeys. - created time.Time -} - -// ticketKeyFromBytes converts from the external representation of a session -// ticket key to a ticketKey. Externally, session ticket keys are 32 random -// bytes and this function expands that into sufficient name and key material. -func (c *config) ticketKeyFromBytes(b [32]byte) (key ticketKey) { - hashed := sha512.Sum512(b[:]) - copy(key.keyName[:], hashed[:ticketKeyNameLen]) - copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16]) - copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32]) - key.created = c.time() - return key -} - -// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session -// ticket, and the lifetime we set for tickets we send. -const maxSessionTicketLifetime = 7 * 24 * time.Hour - -// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is -// being used concurrently by a TLS client or server. -func (c *config) Clone() *config { - if c == nil { - return nil - } - c.mutex.RLock() - defer c.mutex.RUnlock() - return &config{ - Rand: c.Rand, - Time: c.Time, - Certificates: c.Certificates, - NameToCertificate: c.NameToCertificate, - GetCertificate: c.GetCertificate, - GetClientCertificate: c.GetClientCertificate, - GetConfigForClient: c.GetConfigForClient, - VerifyPeerCertificate: c.VerifyPeerCertificate, - VerifyConnection: c.VerifyConnection, - RootCAs: c.RootCAs, - NextProtos: c.NextProtos, - ServerName: c.ServerName, - ClientAuth: c.ClientAuth, - ClientCAs: c.ClientCAs, - InsecureSkipVerify: c.InsecureSkipVerify, - CipherSuites: c.CipherSuites, - PreferServerCipherSuites: c.PreferServerCipherSuites, - SessionTicketsDisabled: c.SessionTicketsDisabled, - SessionTicketKey: c.SessionTicketKey, - ClientSessionCache: c.ClientSessionCache, - MinVersion: c.MinVersion, - MaxVersion: c.MaxVersion, - CurvePreferences: c.CurvePreferences, - DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, - Renegotiation: c.Renegotiation, - KeyLogWriter: c.KeyLogWriter, - sessionTicketKeys: c.sessionTicketKeys, - autoSessionTicketKeys: c.autoSessionTicketKeys, - } -} - -// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was -// randomized for backwards compatibility but is not in use. -var deprecatedSessionTicketKey = []byte("DEPRECATED") - -// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is -// randomized if empty, and that sessionTicketKeys is populated from it otherwise. -func (c *config) initLegacySessionTicketKeyRLocked() { - // Don't write if SessionTicketKey is already defined as our deprecated string, - // or if it is defined by the user but sessionTicketKeys is already set. - if c.SessionTicketKey != [32]byte{} && - (bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) { - return - } - - // We need to write some data, so get an exclusive lock and re-check any conditions. - c.mutex.RUnlock() - defer c.mutex.RLock() - c.mutex.Lock() - defer c.mutex.Unlock() - if c.SessionTicketKey == [32]byte{} { - if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { - panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err)) - } - // Write the deprecated prefix at the beginning so we know we created - // it. This key with the DEPRECATED prefix isn't used as an actual - // session ticket key, and is only randomized in case the application - // reuses it for some reason. - copy(c.SessionTicketKey[:], deprecatedSessionTicketKey) - } else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 { - c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)} - } - -} - -// ticketKeys returns the ticketKeys for this connection. -// If configForClient has explicitly set keys, those will -// be returned. Otherwise, the keys on c will be used and -// may be rotated if auto-managed. -// During rotation, any expired session ticket keys are deleted from -// c.sessionTicketKeys. If the session ticket key that is currently -// encrypting tickets (ie. the first ticketKey in c.sessionTicketKeys) -// is not fresh, then a new session ticket key will be -// created and prepended to c.sessionTicketKeys. -func (c *config) ticketKeys(configForClient *config) []ticketKey { - // If the ConfigForClient callback returned a Config with explicitly set - // keys, use those, otherwise just use the original Config. - if configForClient != nil { - configForClient.mutex.RLock() - if configForClient.SessionTicketsDisabled { - return nil - } - configForClient.initLegacySessionTicketKeyRLocked() - if len(configForClient.sessionTicketKeys) != 0 { - ret := configForClient.sessionTicketKeys - configForClient.mutex.RUnlock() - return ret - } - configForClient.mutex.RUnlock() - } - - c.mutex.RLock() - defer c.mutex.RUnlock() - if c.SessionTicketsDisabled { - return nil - } - c.initLegacySessionTicketKeyRLocked() - if len(c.sessionTicketKeys) != 0 { - return c.sessionTicketKeys - } - // Fast path for the common case where the key is fresh enough. - if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation { - return c.autoSessionTicketKeys - } - - // autoSessionTicketKeys are managed by auto-rotation. - c.mutex.RUnlock() - defer c.mutex.RLock() - c.mutex.Lock() - defer c.mutex.Unlock() - // Re-check the condition in case it changed since obtaining the new lock. - if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation { - var newKey [32]byte - if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil { - panic(fmt.Sprintf("unable to generate random session ticket key: %v", err)) - } - valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1) - valid = append(valid, c.ticketKeyFromBytes(newKey)) - for _, k := range c.autoSessionTicketKeys { - // While rotating the current key, also remove any expired ones. - if c.time().Sub(k.created) < ticketKeyLifetime { - valid = append(valid, k) - } - } - c.autoSessionTicketKeys = valid - } - return c.autoSessionTicketKeys -} - -// SetSessionTicketKeys updates the session ticket keys for a server. -// -// The first key will be used when creating new tickets, while all keys can be -// used for decrypting tickets. It is safe to call this function while the -// server is running in order to rotate the session ticket keys. The function -// will panic if keys is empty. -// -// Calling this function will turn off automatic session ticket key rotation. -// -// If multiple servers are terminating connections for the same host they should -// all have the same session ticket keys. If the session ticket keys leaks, -// previously recorded and future TLS connections using those keys might be -// compromised. -func (c *config) SetSessionTicketKeys(keys [][32]byte) { - if len(keys) == 0 { - panic("tls: keys must have at least one key") - } - - newKeys := make([]ticketKey, len(keys)) - for i, bytes := range keys { - newKeys[i] = c.ticketKeyFromBytes(bytes) - } - - c.mutex.Lock() - c.sessionTicketKeys = newKeys - c.mutex.Unlock() -} - -func (c *config) rand() io.Reader { - r := c.Rand - if r == nil { - return rand.Reader - } - return r -} - -func (c *config) time() time.Time { - t := c.Time - if t == nil { - t = time.Now - } - return t() -} - -func (c *config) cipherSuites() []uint16 { - if needFIPS() { - return fipsCipherSuites(c) - } - if c.CipherSuites != nil { - return c.CipherSuites - } - return defaultCipherSuites -} - -var supportedVersions = []uint16{ - VersionTLS13, - VersionTLS12, - VersionTLS11, - VersionTLS10, -} - -// roleClient and roleServer are meant to call supportedVersions and parents -// with more readability at the callsite. -const roleClient = true -const roleServer = false - -func (c *config) supportedVersions(isClient bool) []uint16 { - versions := make([]uint16, 0, len(supportedVersions)) - for _, v := range supportedVersions { - if needFIPS() && (v < fipsMinVersion(c) || v > fipsMaxVersion(c)) { - continue - } - if (c == nil || c.MinVersion == 0) && - isClient && v < VersionTLS12 { - continue - } - if c != nil && c.MinVersion != 0 && v < c.MinVersion { - continue - } - if c != nil && c.MaxVersion != 0 && v > c.MaxVersion { - continue - } - versions = append(versions, v) - } - return versions -} - -func (c *config) maxSupportedVersion(isClient bool) uint16 { - supportedVersions := c.supportedVersions(isClient) - if len(supportedVersions) == 0 { - return 0 - } - return supportedVersions[0] -} - -// supportedVersionsFromMax returns a list of supported versions derived from a -// legacy maximum version value. Note that only versions supported by this -// library are returned. Any newer peer will use supportedVersions anyway. -func supportedVersionsFromMax(maxVersion uint16) []uint16 { - versions := make([]uint16, 0, len(supportedVersions)) - for _, v := range supportedVersions { - if v > maxVersion { - continue - } - versions = append(versions, v) - } - return versions -} - -var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} - -func (c *config) curvePreferences() []CurveID { - if needFIPS() { - return fipsCurvePreferences(c) - } - if c == nil || len(c.CurvePreferences) == 0 { - return defaultCurvePreferences - } - return c.CurvePreferences -} - -func (c *config) supportsCurve(curve CurveID) bool { - for _, cc := range c.curvePreferences() { - if cc == curve { - return true - } - } - return false -} - -// mutualVersion returns the protocol version to use given the advertised -// versions of the peer. Priority is given to the peer preference order. -func (c *config) mutualVersion(isClient bool, peerVersions []uint16) (uint16, bool) { - supportedVersions := c.supportedVersions(isClient) - for _, peerVersion := range peerVersions { - for _, v := range supportedVersions { - if v == peerVersion { - return v, true - } - } - } - return 0, false -} - -var errNoCertificates = errors.New("tls: no certificates configured") - -// getCertificate returns the best certificate for the given ClientHelloInfo, -// defaulting to the first element of c.Certificates. -func (c *config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { - if c.GetCertificate != nil && - (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { - cert, err := c.GetCertificate(clientHello) - if cert != nil || err != nil { - return cert, err - } - } - - if len(c.Certificates) == 0 { - return nil, errNoCertificates - } - - if len(c.Certificates) == 1 { - // There's only one choice, so no point doing any work. - return &c.Certificates[0], nil - } - - if c.NameToCertificate != nil { - name := strings.ToLower(clientHello.ServerName) - if cert, ok := c.NameToCertificate[name]; ok { - return cert, nil - } - if len(name) > 0 { - labels := strings.Split(name, ".") - labels[0] = "*" - wildcardName := strings.Join(labels, ".") - if cert, ok := c.NameToCertificate[wildcardName]; ok { - return cert, nil - } - } - } - - for _, cert := range c.Certificates { - if err := clientHello.SupportsCertificate(&cert); err == nil { - return &cert, nil - } - } - - // If nothing matches, return the first certificate. - return &c.Certificates[0], nil -} - -// SupportsCertificate returns nil if the provided certificate is supported by -// the client that sent the ClientHello. Otherwise, it returns an error -// describing the reason for the incompatibility. -// -// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate -// callback, this method will take into account the associated Config. Note that -// if GetConfigForClient returns a different Config, the change can't be -// accounted for by this method. -// -// This function will call x509.ParseCertificate unless c.Leaf is set, which can -// incur a significant performance cost. -func (chi *clientHelloInfo) SupportsCertificate(c *Certificate) error { - // Note we don't currently support certificate_authorities nor - // signature_algorithms_cert, and don't check the algorithms of the - // signatures on the chain (which anyway are a SHOULD, see RFC 8446, - // Section 4.4.2.2). - - config := chi.config - if config == nil { - config = &Config{} - } - conf := fromConfig(config) - vers, ok := conf.mutualVersion(roleServer, chi.SupportedVersions) - if !ok { - return errors.New("no mutually supported protocol versions") - } - - // If the client specified the name they are trying to connect to, the - // certificate needs to be valid for it. - if chi.ServerName != "" { - x509Cert, err := leafCertificate(c) - if err != nil { - return fmt.Errorf("failed to parse certificate: %w", err) - } - if err := x509Cert.VerifyHostname(chi.ServerName); err != nil { - return fmt.Errorf("certificate is not valid for requested server name: %w", err) - } - } - - // supportsRSAFallback returns nil if the certificate and connection support - // the static RSA key exchange, and unsupported otherwise. The logic for - // supporting static RSA is completely disjoint from the logic for - // supporting signed key exchanges, so we just check it as a fallback. - supportsRSAFallback := func(unsupported error) error { - // TLS 1.3 dropped support for the static RSA key exchange. - if vers == VersionTLS13 { - return unsupported - } - // The static RSA key exchange works by decrypting a challenge with the - // RSA private key, not by signing, so check the PrivateKey implements - // crypto.Decrypter, like *rsa.PrivateKey does. - if priv, ok := c.PrivateKey.(crypto.Decrypter); ok { - if _, ok := priv.Public().(*rsa.PublicKey); !ok { - return unsupported - } - } else { - return unsupported - } - // Finally, there needs to be a mutual cipher suite that uses the static - // RSA key exchange instead of ECDHE. - rsaCipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { - if c.flags&suiteECDHE != 0 { - return false - } - if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true - }) - if rsaCipherSuite == nil { - return unsupported - } - return nil - } - - // If the client sent the signature_algorithms extension, ensure it supports - // schemes we can use with this certificate and TLS version. - if len(chi.SignatureSchemes) > 0 { - if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil { - return supportsRSAFallback(err) - } - } - - // In TLS 1.3 we are done because supported_groups is only relevant to the - // ECDHE computation, point format negotiation is removed, cipher suites are - // only relevant to the AEAD choice, and static RSA does not exist. - if vers == VersionTLS13 { - return nil - } - - // The only signed key exchange we support is ECDHE. - if !supportsECDHE(conf, chi.SupportedCurves, chi.SupportedPoints) { - return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange")) - } - - var ecdsaCipherSuite bool - if priv, ok := c.PrivateKey.(crypto.Signer); ok { - switch pub := priv.Public().(type) { - case *ecdsa.PublicKey: - var curve CurveID - switch pub.Curve { - case elliptic.P256(): - curve = CurveP256 - case elliptic.P384(): - curve = CurveP384 - case elliptic.P521(): - curve = CurveP521 - default: - return supportsRSAFallback(unsupportedCertificateError(c)) - } - var curveOk bool - for _, c := range chi.SupportedCurves { - if c == curve && conf.supportsCurve(c) { - curveOk = true - break - } - } - if !curveOk { - return errors.New("client doesn't support certificate curve") - } - ecdsaCipherSuite = true - case ed25519.PublicKey: - if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 { - return errors.New("connection doesn't support Ed25519") - } - ecdsaCipherSuite = true - case *rsa.PublicKey: - default: - return supportsRSAFallback(unsupportedCertificateError(c)) - } - } else { - return supportsRSAFallback(unsupportedCertificateError(c)) - } - - // Make sure that there is a mutually supported cipher suite that works with - // this certificate. Cipher suite selection will then apply the logic in - // reverse to pick it. See also serverHandshakeState.cipherSuiteOk. - cipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { - if c.flags&suiteECDHE == 0 { - return false - } - if c.flags&suiteECSign != 0 { - if !ecdsaCipherSuite { - return false - } - } else { - if ecdsaCipherSuite { - return false - } - } - if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true - }) - if cipherSuite == nil { - return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate")) - } - - return nil -} - -// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate -// from the CommonName and SubjectAlternateName fields of each of the leaf -// certificates. -// -// Deprecated: NameToCertificate only allows associating a single certificate -// with a given name. Leave that field nil to let the library select the first -// compatible chain from Certificates. -func (c *config) BuildNameToCertificate() { - c.NameToCertificate = make(map[string]*Certificate) - for i := range c.Certificates { - cert := &c.Certificates[i] - x509Cert, err := leafCertificate(cert) - if err != nil { - continue - } - // If SANs are *not* present, some clients will consider the certificate - // valid for the name in the Common Name. - if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { - c.NameToCertificate[x509Cert.Subject.CommonName] = cert - } - for _, san := range x509Cert.DNSNames { - c.NameToCertificate[san] = cert - } - } -} - -const ( - keyLogLabelTLS12 = "CLIENT_RANDOM" - keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET" - keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET" - keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0" - keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0" -) - -func (c *config) writeKeyLog(label string, clientRandom, secret []byte) error { - if c.KeyLogWriter == nil { - return nil - } - - logLine := fmt.Appendf(nil, "%s %x %x\n", label, clientRandom, secret) - - writerMutex.Lock() - _, err := c.KeyLogWriter.Write(logLine) - writerMutex.Unlock() - - return err -} - -// writerMutex protects all KeyLogWriters globally. It is rarely enabled, -// and is only for debugging, so a global mutex saves space. -var writerMutex sync.Mutex - -// A Certificate is a chain of one or more certificates, leaf first. -type Certificate = tls.Certificate - -// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing -// the corresponding c.Certificate[0]. -func leafCertificate(c *Certificate) (*x509.Certificate, error) { - if c.Leaf != nil { - return c.Leaf, nil - } - return x509.ParseCertificate(c.Certificate[0]) -} - -type handshakeMessage interface { - marshal() ([]byte, error) - unmarshal([]byte) bool -} - -// lruSessionCache is a ClientSessionCache implementation that uses an LRU -// caching strategy. -type lruSessionCache struct { - sync.Mutex - - m map[string]*list.Element - q *list.List - capacity int -} - -type lruSessionCacheEntry struct { - sessionKey string - state *ClientSessionState -} - -// NewLRUClientSessionCache returns a ClientSessionCache with the given -// capacity that uses an LRU strategy. If capacity is < 1, a default capacity -// is used instead. -func NewLRUClientSessionCache(capacity int) ClientSessionCache { - const defaultSessionCacheCapacity = 64 - - if capacity < 1 { - capacity = defaultSessionCacheCapacity - } - return &lruSessionCache{ - m: make(map[string]*list.Element), - q: list.New(), - capacity: capacity, - } -} - -// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry -// corresponding to sessionKey is removed from the cache instead. -func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { - c.Lock() - defer c.Unlock() - - if elem, ok := c.m[sessionKey]; ok { - if cs == nil { - c.q.Remove(elem) - delete(c.m, sessionKey) - } else { - entry := elem.Value.(*lruSessionCacheEntry) - entry.state = cs - c.q.MoveToFront(elem) - } - return - } - - if c.q.Len() < c.capacity { - entry := &lruSessionCacheEntry{sessionKey, cs} - c.m[sessionKey] = c.q.PushFront(entry) - return - } - - elem := c.q.Back() - entry := elem.Value.(*lruSessionCacheEntry) - delete(c.m, entry.sessionKey) - entry.sessionKey = sessionKey - entry.state = cs - c.q.MoveToFront(elem) - c.m[sessionKey] = elem -} - -// Get returns the ClientSessionState value associated with a given key. It -// returns (nil, false) if no value is found. -func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { - c.Lock() - defer c.Unlock() - - if elem, ok := c.m[sessionKey]; ok { - c.q.MoveToFront(elem) - return elem.Value.(*lruSessionCacheEntry).state, true - } - return nil, false -} - -var emptyConfig Config - -func defaultConfig() *Config { - return &emptyConfig -} - -func unexpectedMessageError(wanted, got any) error { - return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) -} - -func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { - for _, s := range supportedSignatureAlgorithms { - if s == sigAlg { - return true - } - } - return false -} - -// CertificateVerificationError is returned when certificate verification fails during the handshake. -type CertificateVerificationError = tls.CertificateVerificationError diff --git a/vendor/github.com/quic-go/qtls-go1-20/conn.go b/vendor/github.com/quic-go/qtls-go1-20/conn.go deleted file mode 100644 index b7ebdb0a7..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/conn.go +++ /dev/null @@ -1,1643 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// TLS low level connection and record layer - -package qtls - -import ( - "bytes" - "context" - "crypto/cipher" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "net" - "sync" - "sync/atomic" - "time" -) - -// A Conn represents a secured connection. -// It implements the net.Conn interface. -type Conn struct { - // constant - conn net.Conn - isClient bool - handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake - quic *quicState // nil for non-QUIC connections - - // isHandshakeComplete is true if the connection is currently transferring - // application data (i.e. is not currently processing a handshake). - // isHandshakeComplete is true implies handshakeErr == nil. - isHandshakeComplete atomic.Bool - // constant after handshake; protected by handshakeMutex - handshakeMutex sync.Mutex - handshakeErr error // error resulting from handshake - vers uint16 // TLS version - haveVers bool // version has been negotiated - config *config // configuration passed to constructor - extraConfig *ExtraConfig - // handshakes counts the number of handshakes performed on the - // connection so far. If renegotiation is disabled then this is either - // zero or one. - handshakes int - didResume bool // whether this connection was a session resumption - cipherSuite uint16 - ocspResponse []byte // stapled OCSP response - scts [][]byte // signed certificate timestamps from server - peerCertificates []*x509.Certificate - // activeCertHandles contains the cache handles to certificates in - // peerCertificates that are used to track active references. - activeCertHandles []*activeCert - // verifiedChains contains the certificate chains that we built, as - // opposed to the ones presented by the server. - verifiedChains [][]*x509.Certificate - // serverName contains the server name indicated by the client, if any. - serverName string - // secureRenegotiation is true if the server echoed the secure - // renegotiation extension. (This is meaningless as a server because - // renegotiation is not supported in that case.) - secureRenegotiation bool - // ekm is a closure for exporting keying material. - ekm func(label string, context []byte, length int) ([]byte, error) - // resumptionSecret is the resumption_master_secret for handling - // or sending NewSessionTicket messages. - resumptionSecret []byte - - // ticketKeys is the set of active session ticket keys for this - // connection. The first one is used to encrypt new tickets and - // all are tried to decrypt tickets. - ticketKeys []ticketKey - - // clientFinishedIsFirst is true if the client sent the first Finished - // message during the most recent handshake. This is recorded because - // the first transmitted Finished message is the tls-unique - // channel-binding value. - clientFinishedIsFirst bool - - // closeNotifyErr is any error from sending the alertCloseNotify record. - closeNotifyErr error - // closeNotifySent is true if the Conn attempted to send an - // alertCloseNotify record. - closeNotifySent bool - - // clientFinished and serverFinished contain the Finished message sent - // by the client or server in the most recent handshake. This is - // retained to support the renegotiation extension and tls-unique - // channel-binding. - clientFinished [12]byte - serverFinished [12]byte - - // clientProtocol is the negotiated ALPN protocol. - clientProtocol string - - // input/output - in, out halfConn - rawInput bytes.Buffer // raw input, starting with a record header - input bytes.Reader // application data waiting to be read, from rawInput.Next - hand bytes.Buffer // handshake data waiting to be read - buffering bool // whether records are buffered in sendBuf - sendBuf []byte // a buffer of records waiting to be sent - - // bytesSent counts the bytes of application data sent. - // packetsSent counts packets. - bytesSent int64 - packetsSent int64 - - // retryCount counts the number of consecutive non-advancing records - // received by Conn.readRecord. That is, records that neither advance the - // handshake, nor deliver application data. Protected by in.Mutex. - retryCount int - - // activeCall indicates whether Close has been call in the low bit. - // the rest of the bits are the number of goroutines in Conn.Write. - activeCall atomic.Int32 - - tmp [16]byte -} - -// Access to net.Conn methods. -// Cannot just embed net.Conn because that would -// export the struct field too. - -// LocalAddr returns the local network address. -func (c *Conn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -// RemoteAddr returns the remote network address. -func (c *Conn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -// SetDeadline sets the read and write deadlines associated with the connection. -// A zero value for t means Read and Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. -func (c *Conn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -// SetReadDeadline sets the read deadline on the underlying connection. -// A zero value for t means Read will not time out. -func (c *Conn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -// SetWriteDeadline sets the write deadline on the underlying connection. -// A zero value for t means Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. -func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - -// NetConn returns the underlying connection that is wrapped by c. -// Note that writing to or reading from this connection directly will corrupt the -// TLS session. -func (c *Conn) NetConn() net.Conn { - return c.conn -} - -// A halfConn represents one direction of the record layer -// connection, either sending or receiving. -type halfConn struct { - sync.Mutex - - err error // first permanent error - version uint16 // protocol version - cipher any // cipher algorithm - mac hash.Hash - seq [8]byte // 64-bit sequence number - - scratchBuf [13]byte // to avoid allocs; interface method args escape - - nextCipher any // next encryption state - nextMac hash.Hash // next MAC algorithm - - level QUICEncryptionLevel // current QUIC encryption level - trafficSecret []byte // current TLS 1.3 traffic secret -} - -type permanentError struct { - err net.Error -} - -func (e *permanentError) Error() string { return e.err.Error() } -func (e *permanentError) Unwrap() error { return e.err } -func (e *permanentError) Timeout() bool { return e.err.Timeout() } -func (e *permanentError) Temporary() bool { return false } - -func (hc *halfConn) setErrorLocked(err error) error { - if e, ok := err.(net.Error); ok { - hc.err = &permanentError{err: e} - } else { - hc.err = err - } - return hc.err -} - -// prepareCipherSpec sets the encryption and MAC states -// that a subsequent changeCipherSpec will use. -func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) { - hc.version = version - hc.nextCipher = cipher - hc.nextMac = mac -} - -// changeCipherSpec changes the encryption and MAC states -// to the ones previously passed to prepareCipherSpec. -func (hc *halfConn) changeCipherSpec() error { - if hc.nextCipher == nil || hc.version == VersionTLS13 { - return alertInternalError - } - hc.cipher = hc.nextCipher - hc.mac = hc.nextMac - hc.nextCipher = nil - hc.nextMac = nil - for i := range hc.seq { - hc.seq[i] = 0 - } - return nil -} - -func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) { - hc.trafficSecret = secret - hc.level = level - key, iv := suite.trafficKey(secret) - hc.cipher = suite.aead(key, iv) - for i := range hc.seq { - hc.seq[i] = 0 - } -} - -// incSeq increments the sequence number. -func (hc *halfConn) incSeq() { - for i := 7; i >= 0; i-- { - hc.seq[i]++ - if hc.seq[i] != 0 { - return - } - } - - // Not allowed to let sequence number wrap. - // Instead, must renegotiate before it does. - // Not likely enough to bother. - panic("TLS: sequence number wraparound") -} - -// explicitNonceLen returns the number of bytes of explicit nonce or IV included -// in each record. Explicit nonces are present only in CBC modes after TLS 1.0 -// and in certain AEAD modes in TLS 1.2. -func (hc *halfConn) explicitNonceLen() int { - if hc.cipher == nil { - return 0 - } - - switch c := hc.cipher.(type) { - case cipher.Stream: - return 0 - case aead: - return c.explicitNonceLen() - case cbcMode: - // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack. - if hc.version >= VersionTLS11 { - return c.BlockSize() - } - return 0 - default: - panic("unknown cipher type") - } -} - -// extractPadding returns, in constant time, the length of the padding to remove -// from the end of payload. It also returns a byte which is equal to 255 if the -// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. -func extractPadding(payload []byte) (toRemove int, good byte) { - if len(payload) < 1 { - return 0, 0 - } - - paddingLen := payload[len(payload)-1] - t := uint(len(payload)-1) - uint(paddingLen) - // if len(payload) >= (paddingLen - 1) then the MSB of t is zero - good = byte(int32(^t) >> 31) - - // The maximum possible padding length plus the actual length field - toCheck := 256 - // The length of the padded data is public, so we can use an if here - if toCheck > len(payload) { - toCheck = len(payload) - } - - for i := 0; i < toCheck; i++ { - t := uint(paddingLen) - uint(i) - // if i <= paddingLen then the MSB of t is zero - mask := byte(int32(^t) >> 31) - b := payload[len(payload)-1-i] - good &^= mask&paddingLen ^ mask&b - } - - // We AND together the bits of good and replicate the result across - // all the bits. - good &= good << 4 - good &= good << 2 - good &= good << 1 - good = uint8(int8(good) >> 7) - - // Zero the padding length on error. This ensures any unchecked bytes - // are included in the MAC. Otherwise, an attacker that could - // distinguish MAC failures from padding failures could mount an attack - // similar to POODLE in SSL 3.0: given a good ciphertext that uses a - // full block's worth of padding, replace the final block with another - // block. If the MAC check passed but the padding check failed, the - // last byte of that block decrypted to the block size. - // - // See also macAndPaddingGood logic below. - paddingLen &= good - - toRemove = int(paddingLen) + 1 - return -} - -func roundUp(a, b int) int { - return a + (b-a%b)%b -} - -// cbcMode is an interface for block ciphers using cipher block chaining. -type cbcMode interface { - cipher.BlockMode - SetIV([]byte) -} - -// decrypt authenticates and decrypts the record if protection is active at -// this stage. The returned plaintext might overlap with the input. -func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { - var plaintext []byte - typ := recordType(record[0]) - payload := record[recordHeaderLen:] - - // In TLS 1.3, change_cipher_spec messages are to be ignored without being - // decrypted. See RFC 8446, Appendix D.4. - if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec { - return payload, typ, nil - } - - paddingGood := byte(255) - paddingLen := 0 - - explicitNonceLen := hc.explicitNonceLen() - - if hc.cipher != nil { - switch c := hc.cipher.(type) { - case cipher.Stream: - c.XORKeyStream(payload, payload) - case aead: - if len(payload) < explicitNonceLen { - return nil, 0, alertBadRecordMAC - } - nonce := payload[:explicitNonceLen] - if len(nonce) == 0 { - nonce = hc.seq[:] - } - payload = payload[explicitNonceLen:] - - var additionalData []byte - if hc.version == VersionTLS13 { - additionalData = record[:recordHeaderLen] - } else { - additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) - additionalData = append(additionalData, record[:3]...) - n := len(payload) - c.Overhead() - additionalData = append(additionalData, byte(n>>8), byte(n)) - } - - var err error - plaintext, err = c.Open(payload[:0], nonce, payload, additionalData) - if err != nil { - return nil, 0, alertBadRecordMAC - } - case cbcMode: - blockSize := c.BlockSize() - minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) - if len(payload)%blockSize != 0 || len(payload) < minPayload { - return nil, 0, alertBadRecordMAC - } - - if explicitNonceLen > 0 { - c.SetIV(payload[:explicitNonceLen]) - payload = payload[explicitNonceLen:] - } - c.CryptBlocks(payload, payload) - - // In a limited attempt to protect against CBC padding oracles like - // Lucky13, the data past paddingLen (which is secret) is passed to - // the MAC function as extra data, to be fed into the HMAC after - // computing the digest. This makes the MAC roughly constant time as - // long as the digest computation is constant time and does not - // affect the subsequent write, modulo cache effects. - paddingLen, paddingGood = extractPadding(payload) - default: - panic("unknown cipher type") - } - - if hc.version == VersionTLS13 { - if typ != recordTypeApplicationData { - return nil, 0, alertUnexpectedMessage - } - if len(plaintext) > maxPlaintext+1 { - return nil, 0, alertRecordOverflow - } - // Remove padding and find the ContentType scanning from the end. - for i := len(plaintext) - 1; i >= 0; i-- { - if plaintext[i] != 0 { - typ = recordType(plaintext[i]) - plaintext = plaintext[:i] - break - } - if i == 0 { - return nil, 0, alertUnexpectedMessage - } - } - } - } else { - plaintext = payload - } - - if hc.mac != nil { - macSize := hc.mac.Size() - if len(payload) < macSize { - return nil, 0, alertBadRecordMAC - } - - n := len(payload) - macSize - paddingLen - n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } - record[3] = byte(n >> 8) - record[4] = byte(n) - remoteMAC := payload[n : n+macSize] - localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) - - // This is equivalent to checking the MACs and paddingGood - // separately, but in constant-time to prevent distinguishing - // padding failures from MAC failures. Depending on what value - // of paddingLen was returned on bad padding, distinguishing - // bad MAC from bad padding can lead to an attack. - // - // See also the logic at the end of extractPadding. - macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood) - if macAndPaddingGood != 1 { - return nil, 0, alertBadRecordMAC - } - - plaintext = payload[:n] - } - - hc.incSeq() - return plaintext, typ, nil -} - -// sliceForAppend extends the input slice by n bytes. head is the full extended -// slice, while tail is the appended part. If the original slice has sufficient -// capacity no allocation is performed. -func sliceForAppend(in []byte, n int) (head, tail []byte) { - if total := len(in) + n; cap(in) >= total { - head = in[:total] - } else { - head = make([]byte, total) - copy(head, in) - } - tail = head[len(in):] - return -} - -// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and -// appends it to record, which must already contain the record header. -func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { - if hc.cipher == nil { - return append(record, payload...), nil - } - - var explicitNonce []byte - if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 { - record, explicitNonce = sliceForAppend(record, explicitNonceLen) - if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 { - // The AES-GCM construction in TLS has an explicit nonce so that the - // nonce can be random. However, the nonce is only 8 bytes which is - // too small for a secure, random nonce. Therefore we use the - // sequence number as the nonce. The 3DES-CBC construction also has - // an 8 bytes nonce but its nonces must be unpredictable (see RFC - // 5246, Appendix F.3), forcing us to use randomness. That's not - // 3DES' biggest problem anyway because the birthday bound on block - // collision is reached first due to its similarly small block size - // (see the Sweet32 attack). - copy(explicitNonce, hc.seq[:]) - } else { - if _, err := io.ReadFull(rand, explicitNonce); err != nil { - return nil, err - } - } - } - - var dst []byte - switch c := hc.cipher.(type) { - case cipher.Stream: - mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) - record, dst = sliceForAppend(record, len(payload)+len(mac)) - c.XORKeyStream(dst[:len(payload)], payload) - c.XORKeyStream(dst[len(payload):], mac) - case aead: - nonce := explicitNonce - if len(nonce) == 0 { - nonce = hc.seq[:] - } - - if hc.version == VersionTLS13 { - record = append(record, payload...) - - // Encrypt the actual ContentType and replace the plaintext one. - record = append(record, record[0]) - record[0] = byte(recordTypeApplicationData) - - n := len(payload) + 1 + c.Overhead() - record[3] = byte(n >> 8) - record[4] = byte(n) - - record = c.Seal(record[:recordHeaderLen], - nonce, record[recordHeaderLen:], record[:recordHeaderLen]) - } else { - additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) - additionalData = append(additionalData, record[:recordHeaderLen]...) - record = c.Seal(record, nonce, payload, additionalData) - } - case cbcMode: - mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) - blockSize := c.BlockSize() - plaintextLen := len(payload) + len(mac) - paddingLen := blockSize - plaintextLen%blockSize - record, dst = sliceForAppend(record, plaintextLen+paddingLen) - copy(dst, payload) - copy(dst[len(payload):], mac) - for i := plaintextLen; i < len(dst); i++ { - dst[i] = byte(paddingLen - 1) - } - if len(explicitNonce) > 0 { - c.SetIV(explicitNonce) - } - c.CryptBlocks(dst, dst) - default: - panic("unknown cipher type") - } - - // Update length to include nonce, MAC and any block padding needed. - n := len(record) - recordHeaderLen - record[3] = byte(n >> 8) - record[4] = byte(n) - hc.incSeq() - - return record, nil -} - -// RecordHeaderError is returned when a TLS record header is invalid. -type RecordHeaderError struct { - // Msg contains a human readable string that describes the error. - Msg string - // RecordHeader contains the five bytes of TLS record header that - // triggered the error. - RecordHeader [5]byte - // Conn provides the underlying net.Conn in the case that a client - // sent an initial handshake that didn't look like TLS. - // It is nil if there's already been a handshake or a TLS alert has - // been written to the connection. - Conn net.Conn -} - -func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } - -func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) { - err.Msg = msg - err.Conn = conn - copy(err.RecordHeader[:], c.rawInput.Bytes()) - return err -} - -func (c *Conn) readRecord() error { - return c.readRecordOrCCS(false) -} - -func (c *Conn) readChangeCipherSpec() error { - return c.readRecordOrCCS(true) -} - -// readRecordOrCCS reads one or more TLS records from the connection and -// updates the record layer state. Some invariants: -// - c.in must be locked -// - c.input must be empty -// -// During the handshake one and only one of the following will happen: -// - c.hand grows -// - c.in.changeCipherSpec is called -// - an error is returned -// -// After the handshake one and only one of the following will happen: -// - c.hand grows -// - c.input is set -// - an error is returned -func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { - if c.in.err != nil { - return c.in.err - } - handshakeComplete := c.isHandshakeComplete.Load() - - // This function modifies c.rawInput, which owns the c.input memory. - if c.input.Len() != 0 { - return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data")) - } - c.input.Reset(nil) - - if c.quic != nil { - return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport")) - } - - // Read header, payload. - if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil { - // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify - // is an error, but popular web sites seem to do this, so we accept it - // if and only if at the record boundary. - if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 { - err = io.EOF - } - if e, ok := err.(net.Error); !ok || !e.Temporary() { - c.in.setErrorLocked(err) - } - return err - } - hdr := c.rawInput.Bytes()[:recordHeaderLen] - typ := recordType(hdr[0]) - - // No valid TLS record has a type of 0x80, however SSLv2 handshakes - // start with a uint16 length where the MSB is set and the first record - // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests - // an SSLv2 client. - if !handshakeComplete && typ == 0x80 { - c.sendAlert(alertProtocolVersion) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) - } - - vers := uint16(hdr[1])<<8 | uint16(hdr[2]) - n := int(hdr[3])<<8 | int(hdr[4]) - if c.haveVers && c.vers != VersionTLS13 && vers != c.vers { - c.sendAlert(alertProtocolVersion) - msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) - } - if !c.haveVers { - // First message, be extra suspicious: this might not be a TLS - // client. Bail out before reading a full 'body', if possible. - // The current max version is 3.3 so if the version is >= 16.0, - // it's probably not real. - if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 { - return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) - } - } - if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { - c.sendAlert(alertRecordOverflow) - msg := fmt.Sprintf("oversized record received with length %d", n) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) - } - if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil { - if e, ok := err.(net.Error); !ok || !e.Temporary() { - c.in.setErrorLocked(err) - } - return err - } - - // Process message. - record := c.rawInput.Next(recordHeaderLen + n) - data, typ, err := c.in.decrypt(record) - if err != nil { - return c.in.setErrorLocked(c.sendAlert(err.(alert))) - } - if len(data) > maxPlaintext { - return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) - } - - // Application Data messages are always protected. - if c.in.cipher == nil && typ == recordTypeApplicationData { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { - // This is a state-advancing message: reset the retry count. - c.retryCount = 0 - } - - // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. - if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - switch typ { - default: - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - - case recordTypeAlert: - if c.quic != nil { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - if len(data) != 2 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - if alert(data[1]) == alertCloseNotify { - return c.in.setErrorLocked(io.EOF) - } - if c.vers == VersionTLS13 { - return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) - } - switch data[0] { - case alertLevelWarning: - // Drop the record on the floor and retry. - return c.retryReadRecord(expectChangeCipherSpec) - case alertLevelError: - return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) - default: - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - case recordTypeChangeCipherSpec: - if len(data) != 1 || data[0] != 1 { - return c.in.setErrorLocked(c.sendAlert(alertDecodeError)) - } - // Handshake messages are not allowed to fragment across the CCS. - if c.hand.Len() > 0 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - // In TLS 1.3, change_cipher_spec records are ignored until the - // Finished. See RFC 8446, Appendix D.4. Note that according to Section - // 5, a server can send a ChangeCipherSpec before its ServerHello, when - // c.vers is still unset. That's not useful though and suspicious if the - // server then selects a lower protocol version, so don't allow that. - if c.vers == VersionTLS13 { - return c.retryReadRecord(expectChangeCipherSpec) - } - if !expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - if err := c.in.changeCipherSpec(); err != nil { - return c.in.setErrorLocked(c.sendAlert(err.(alert))) - } - - case recordTypeApplicationData: - if !handshakeComplete || expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - // Some OpenSSL servers send empty records in order to randomize the - // CBC IV. Ignore a limited number of empty records. - if len(data) == 0 { - return c.retryReadRecord(expectChangeCipherSpec) - } - // Note that data is owned by c.rawInput, following the Next call above, - // to avoid copying the plaintext. This is safe because c.rawInput is - // not read from or written to until c.input is drained. - c.input.Reset(data) - - case recordTypeHandshake: - if len(data) == 0 || expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - c.hand.Write(data) - } - - return nil -} - -// retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like -// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. -func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { - c.retryCount++ - if c.retryCount > maxUselessRecords { - c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(errors.New("tls: too many ignored records")) - } - return c.readRecordOrCCS(expectChangeCipherSpec) -} - -// atLeastReader reads from R, stopping with EOF once at least N bytes have been -// read. It is different from an io.LimitedReader in that it doesn't cut short -// the last Read call, and in that it considers an early EOF an error. -type atLeastReader struct { - R io.Reader - N int64 -} - -func (r *atLeastReader) Read(p []byte) (int, error) { - if r.N <= 0 { - return 0, io.EOF - } - n, err := r.R.Read(p) - r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809 - if r.N > 0 && err == io.EOF { - return n, io.ErrUnexpectedEOF - } - if r.N <= 0 && err == nil { - return n, io.EOF - } - return n, err -} - -// readFromUntil reads from r into c.rawInput until c.rawInput contains -// at least n bytes or else returns an error. -func (c *Conn) readFromUntil(r io.Reader, n int) error { - if c.rawInput.Len() >= n { - return nil - } - needs := n - c.rawInput.Len() - // There might be extra input waiting on the wire. Make a best effort - // attempt to fetch it so that it can be used in (*Conn).Read to - // "predict" closeNotify alerts. - c.rawInput.Grow(needs + bytes.MinRead) - _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)}) - return err -} - -// sendAlert sends a TLS alert message. -func (c *Conn) sendAlertLocked(err alert) error { - if c.quic != nil { - return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) - } - switch err { - case alertNoRenegotiation, alertCloseNotify: - c.tmp[0] = alertLevelWarning - default: - c.tmp[0] = alertLevelError - } - c.tmp[1] = byte(err) - - _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) - if err == alertCloseNotify { - // closeNotify is a special case in that it isn't an error. - return writeErr - } - - return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) -} - -// sendAlert sends a TLS alert message. -func (c *Conn) sendAlert(err alert) error { - c.out.Lock() - defer c.out.Unlock() - return c.sendAlertLocked(err) -} - -const ( - // tcpMSSEstimate is a conservative estimate of the TCP maximum segment - // size (MSS). A constant is used, rather than querying the kernel for - // the actual MSS, to avoid complexity. The value here is the IPv6 - // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 - // bytes) and a TCP header with timestamps (32 bytes). - tcpMSSEstimate = 1208 - - // recordSizeBoostThreshold is the number of bytes of application data - // sent after which the TLS record size will be increased to the - // maximum. - recordSizeBoostThreshold = 128 * 1024 -) - -// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the -// next application data record. There is the following trade-off: -// -// - For latency-sensitive applications, such as web browsing, each TLS -// record should fit in one TCP segment. -// - For throughput-sensitive applications, such as large file transfers, -// larger TLS records better amortize framing and encryption overheads. -// -// A simple heuristic that works well in practice is to use small records for -// the first 1MB of data, then use larger records for subsequent data, and -// reset back to smaller records after the connection becomes idle. See "High -// Performance Web Networking", Chapter 4, or: -// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ -// -// In the interests of simplicity and determinism, this code does not attempt -// to reset the record size once the connection is idle, however. -func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { - if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { - return maxPlaintext - } - - if c.bytesSent >= recordSizeBoostThreshold { - return maxPlaintext - } - - // Subtract TLS overheads to get the maximum payload size. - payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen() - if c.out.cipher != nil { - switch ciph := c.out.cipher.(type) { - case cipher.Stream: - payloadBytes -= c.out.mac.Size() - case cipher.AEAD: - payloadBytes -= ciph.Overhead() - case cbcMode: - blockSize := ciph.BlockSize() - // The payload must fit in a multiple of blockSize, with - // room for at least one padding byte. - payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 - // The MAC is appended before padding so affects the - // payload size directly. - payloadBytes -= c.out.mac.Size() - default: - panic("unknown cipher type") - } - } - if c.vers == VersionTLS13 { - payloadBytes-- // encrypted ContentType - } - - // Allow packet growth in arithmetic progression up to max. - pkt := c.packetsSent - c.packetsSent++ - if pkt > 1000 { - return maxPlaintext // avoid overflow in multiply below - } - - n := payloadBytes * int(pkt+1) - if n > maxPlaintext { - n = maxPlaintext - } - return n -} - -func (c *Conn) write(data []byte) (int, error) { - if c.buffering { - c.sendBuf = append(c.sendBuf, data...) - return len(data), nil - } - - n, err := c.conn.Write(data) - c.bytesSent += int64(n) - return n, err -} - -func (c *Conn) flush() (int, error) { - if len(c.sendBuf) == 0 { - return 0, nil - } - - n, err := c.conn.Write(c.sendBuf) - c.bytesSent += int64(n) - c.sendBuf = nil - c.buffering = false - return n, err -} - -// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. -var outBufPool = sync.Pool{ - New: func() any { - return new([]byte) - }, -} - -// writeRecordLocked writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { - if c.quic != nil { - if typ != recordTypeHandshake { - return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport") - } - c.quicWriteCryptoData(c.out.level, data) - if !c.buffering { - if _, err := c.flush(); err != nil { - return 0, err - } - } - return len(data), nil - } - - outBufPtr := outBufPool.Get().(*[]byte) - outBuf := *outBufPtr - defer func() { - // You might be tempted to simplify this by just passing &outBuf to Put, - // but that would make the local copy of the outBuf slice header escape - // to the heap, causing an allocation. Instead, we keep around the - // pointer to the slice header returned by Get, which is already on the - // heap, and overwrite and return that. - *outBufPtr = outBuf - outBufPool.Put(outBufPtr) - }() - - var n int - for len(data) > 0 { - m := len(data) - if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { - m = maxPayload - } - - _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) - outBuf[0] = byte(typ) - vers := c.vers - if vers == 0 { - // Some TLS servers fail if the record version is - // greater than TLS 1.0 for the initial ClientHello. - vers = VersionTLS10 - } else if vers == VersionTLS13 { - // TLS 1.3 froze the record layer version to 1.2. - // See RFC 8446, Section 5.1. - vers = VersionTLS12 - } - outBuf[1] = byte(vers >> 8) - outBuf[2] = byte(vers) - outBuf[3] = byte(m >> 8) - outBuf[4] = byte(m) - - var err error - outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) - if err != nil { - return n, err - } - if _, err := c.write(outBuf); err != nil { - return n, err - } - n += m - data = data[m:] - } - - if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 { - if err := c.out.changeCipherSpec(); err != nil { - return n, c.sendAlertLocked(err.(alert)) - } - } - - return n, nil -} - -// writeHandshakeRecord writes a handshake message to the connection and updates -// the record layer state. If transcript is non-nil the marshalled message is -// written to it. -func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { - c.out.Lock() - defer c.out.Unlock() - - data, err := msg.marshal() - if err != nil { - return 0, err - } - if transcript != nil { - transcript.Write(data) - } - - return c.writeRecordLocked(recordTypeHandshake, data) -} - -// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and -// updates the record layer state. -func (c *Conn) writeChangeCipherRecord() error { - c.out.Lock() - defer c.out.Unlock() - _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1}) - return err -} - -// readHandshakeBytes reads handshake data until c.hand contains at least n bytes. -func (c *Conn) readHandshakeBytes(n int) error { - if c.quic != nil { - return c.quicReadHandshakeBytes(n) - } - for c.hand.Len() < n { - if err := c.readRecord(); err != nil { - return err - } - } - return nil -} - -// readHandshake reads the next handshake message from -// the record layer. If transcript is non-nil, the message -// is written to the passed transcriptHash. -func (c *Conn) readHandshake(transcript transcriptHash) (any, error) { - if err := c.readHandshakeBytes(4); err != nil { - return nil, err - } - data := c.hand.Bytes() - n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) - if n > maxHandshake { - c.sendAlertLocked(alertInternalError) - return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) - } - if err := c.readHandshakeBytes(4 + n); err != nil { - return nil, err - } - data = c.hand.Next(4 + n) - return c.unmarshalHandshakeMessage(data, transcript) -} - -func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) { - var m handshakeMessage - switch data[0] { - case typeHelloRequest: - m = new(helloRequestMsg) - case typeClientHello: - m = new(clientHelloMsg) - case typeServerHello: - m = new(serverHelloMsg) - case typeNewSessionTicket: - if c.vers == VersionTLS13 { - m = new(newSessionTicketMsgTLS13) - } else { - m = new(newSessionTicketMsg) - } - case typeCertificate: - if c.vers == VersionTLS13 { - m = new(certificateMsgTLS13) - } else { - m = new(certificateMsg) - } - case typeCertificateRequest: - if c.vers == VersionTLS13 { - m = new(certificateRequestMsgTLS13) - } else { - m = &certificateRequestMsg{ - hasSignatureAlgorithm: c.vers >= VersionTLS12, - } - } - case typeCertificateStatus: - m = new(certificateStatusMsg) - case typeServerKeyExchange: - m = new(serverKeyExchangeMsg) - case typeServerHelloDone: - m = new(serverHelloDoneMsg) - case typeClientKeyExchange: - m = new(clientKeyExchangeMsg) - case typeCertificateVerify: - m = &certificateVerifyMsg{ - hasSignatureAlgorithm: c.vers >= VersionTLS12, - } - case typeFinished: - m = new(finishedMsg) - case typeEncryptedExtensions: - m = new(encryptedExtensionsMsg) - case typeEndOfEarlyData: - m = new(endOfEarlyDataMsg) - case typeKeyUpdate: - m = new(keyUpdateMsg) - default: - return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - // The handshake message unmarshalers - // expect to be able to keep references to data, - // so pass in a fresh copy that won't be overwritten. - data = append([]byte(nil), data...) - - if !m.unmarshal(data) { - return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - if transcript != nil { - transcript.Write(data) - } - - return m, nil -} - -var ( - errShutdown = errors.New("tls: protocol is shutdown") -) - -// Write writes data to the connection. -// -// As Write calls Handshake, in order to prevent indefinite blocking a deadline -// must be set for both Read and Write before Write is called when the handshake -// has not yet completed. See SetDeadline, SetReadDeadline, and -// SetWriteDeadline. -func (c *Conn) Write(b []byte) (int, error) { - // interlock with Close below - for { - x := c.activeCall.Load() - if x&1 != 0 { - return 0, net.ErrClosed - } - if c.activeCall.CompareAndSwap(x, x+2) { - break - } - } - defer c.activeCall.Add(-2) - - if err := c.Handshake(); err != nil { - return 0, err - } - - c.out.Lock() - defer c.out.Unlock() - - if err := c.out.err; err != nil { - return 0, err - } - - if !c.isHandshakeComplete.Load() { - return 0, alertInternalError - } - - if c.closeNotifySent { - return 0, errShutdown - } - - // TLS 1.0 is susceptible to a chosen-plaintext - // attack when using block mode ciphers due to predictable IVs. - // This can be prevented by splitting each Application Data - // record into two records, effectively randomizing the IV. - // - // https://www.openssl.org/~bodo/tls-cbc.txt - // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 - // https://www.imperialviolet.org/2012/01/15/beastfollowup.html - - var m int - if len(b) > 1 && c.vers == VersionTLS10 { - if _, ok := c.out.cipher.(cipher.BlockMode); ok { - n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) - if err != nil { - return n, c.out.setErrorLocked(err) - } - m, b = 1, b[1:] - } - } - - n, err := c.writeRecordLocked(recordTypeApplicationData, b) - return n + m, c.out.setErrorLocked(err) -} - -// handleRenegotiation processes a HelloRequest handshake message. -func (c *Conn) handleRenegotiation() error { - if c.vers == VersionTLS13 { - return errors.New("tls: internal error: unexpected renegotiation") - } - - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - helloReq, ok := msg.(*helloRequestMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(helloReq, msg) - } - - if !c.isClient { - return c.sendAlert(alertNoRenegotiation) - } - - switch c.config.Renegotiation { - case RenegotiateNever: - return c.sendAlert(alertNoRenegotiation) - case RenegotiateOnceAsClient: - if c.handshakes > 1 { - return c.sendAlert(alertNoRenegotiation) - } - case RenegotiateFreelyAsClient: - // Ok. - default: - c.sendAlert(alertInternalError) - return errors.New("tls: unknown Renegotiation value") - } - - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - c.isHandshakeComplete.Store(false) - if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { - c.handshakes++ - } - return c.handshakeErr -} - -// handlePostHandshakeMessage processes a handshake message arrived after the -// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. -func (c *Conn) handlePostHandshakeMessage() error { - if c.vers != VersionTLS13 { - return c.handleRenegotiation() - } - - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - c.retryCount++ - if c.retryCount > maxUselessRecords { - c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) - } - - switch msg := msg.(type) { - case *newSessionTicketMsgTLS13: - return c.handleNewSessionTicket(msg) - case *keyUpdateMsg: - return c.handleKeyUpdate(msg) - } - // The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest - // as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an - // unexpected_message alert here doesn't provide it with enough information to distinguish - // this condition from other unexpected messages. This is probably fine. - c.sendAlert(alertUnexpectedMessage) - return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) -} - -func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { - if c.quic != nil { - c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(errors.New("tls: received unexpected key update message")) - } - - cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) - if cipherSuite == nil { - return c.in.setErrorLocked(c.sendAlert(alertInternalError)) - } - - newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) - c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret) - - if keyUpdate.updateRequested { - c.out.Lock() - defer c.out.Unlock() - - msg := &keyUpdateMsg{} - msgBytes, err := msg.marshal() - if err != nil { - return err - } - _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes) - if err != nil { - // Surface the error at the next write. - c.out.setErrorLocked(err) - return nil - } - - newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) - c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret) - } - - return nil -} - -// Read reads data from the connection. -// -// As Read calls Handshake, in order to prevent indefinite blocking a deadline -// must be set for both Read and Write before Read is called when the handshake -// has not yet completed. See SetDeadline, SetReadDeadline, and -// SetWriteDeadline. -func (c *Conn) Read(b []byte) (int, error) { - if err := c.Handshake(); err != nil { - return 0, err - } - if len(b) == 0 { - // Put this after Handshake, in case people were calling - // Read(nil) for the side effect of the Handshake. - return 0, nil - } - - c.in.Lock() - defer c.in.Unlock() - - for c.input.Len() == 0 { - if err := c.readRecord(); err != nil { - return 0, err - } - for c.hand.Len() > 0 { - if err := c.handlePostHandshakeMessage(); err != nil { - return 0, err - } - } - } - - n, _ := c.input.Read(b) - - // If a close-notify alert is waiting, read it so that we can return (n, - // EOF) instead of (n, nil), to signal to the HTTP response reading - // goroutine that the connection is now closed. This eliminates a race - // where the HTTP response reading goroutine would otherwise not observe - // the EOF until its next read, by which time a client goroutine might - // have already tried to reuse the HTTP connection for a new request. - // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 - if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && - recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { - if err := c.readRecord(); err != nil { - return n, err // will be io.EOF on closeNotify - } - } - - return n, nil -} - -// Close closes the connection. -func (c *Conn) Close() error { - // Interlock with Conn.Write above. - var x int32 - for { - x = c.activeCall.Load() - if x&1 != 0 { - return net.ErrClosed - } - if c.activeCall.CompareAndSwap(x, x|1) { - break - } - } - if x != 0 { - // io.Writer and io.Closer should not be used concurrently. - // If Close is called while a Write is currently in-flight, - // interpret that as a sign that this Close is really just - // being used to break the Write and/or clean up resources and - // avoid sending the alertCloseNotify, which may block - // waiting on handshakeMutex or the c.out mutex. - return c.conn.Close() - } - - var alertErr error - if c.isHandshakeComplete.Load() { - if err := c.closeNotify(); err != nil { - alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) - } - } - - if err := c.conn.Close(); err != nil { - return err - } - return alertErr -} - -var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") - -// CloseWrite shuts down the writing side of the connection. It should only be -// called once the handshake has completed and does not call CloseWrite on the -// underlying connection. Most callers should just use Close. -func (c *Conn) CloseWrite() error { - if !c.isHandshakeComplete.Load() { - return errEarlyCloseWrite - } - - return c.closeNotify() -} - -func (c *Conn) closeNotify() error { - c.out.Lock() - defer c.out.Unlock() - - if !c.closeNotifySent { - // Set a Write Deadline to prevent possibly blocking forever. - c.SetWriteDeadline(time.Now().Add(time.Second * 5)) - c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) - c.closeNotifySent = true - // Any subsequent writes will fail. - c.SetWriteDeadline(time.Now()) - } - return c.closeNotifyErr -} - -// Handshake runs the client or server handshake -// protocol if it has not yet been run. -// -// Most uses of this package need not call Handshake explicitly: the -// first Read or Write will call it automatically. -// -// For control over canceling or setting a timeout on a handshake, use -// HandshakeContext or the Dialer's DialContext method instead. -func (c *Conn) Handshake() error { - return c.HandshakeContext(context.Background()) -} - -// HandshakeContext runs the client or server handshake -// protocol if it has not yet been run. -// -// The provided Context must be non-nil. If the context is canceled before -// the handshake is complete, the handshake is interrupted and an error is returned. -// Once the handshake has completed, cancellation of the context will not affect the -// connection. -// -// Most uses of this package need not call HandshakeContext explicitly: the -// first Read or Write will call it automatically. -func (c *Conn) HandshakeContext(ctx context.Context) error { - // Delegate to unexported method for named return - // without confusing documented signature. - return c.handshakeContext(ctx) -} - -func (c *Conn) handshakeContext(ctx context.Context) (ret error) { - // Fast sync/atomic-based exit if there is no handshake in flight and the - // last one succeeded without an error. Avoids the expensive context setup - // and mutex for most Read and Write calls. - if c.isHandshakeComplete.Load() { - return nil - } - - handshakeCtx, cancel := context.WithCancel(ctx) - // Note: defer this before starting the "interrupter" goroutine - // so that we can tell the difference between the input being canceled and - // this cancellation. In the former case, we need to close the connection. - defer cancel() - - if c.quic != nil { - c.quic.cancelc = handshakeCtx.Done() - c.quic.cancel = cancel - } else if ctx.Done() != nil { - // Start the "interrupter" goroutine, if this context might be canceled. - // (The background context cannot). - // - // The interrupter goroutine waits for the input context to be done and - // closes the connection if this happens before the function returns. - done := make(chan struct{}) - interruptRes := make(chan error, 1) - defer func() { - close(done) - if ctxErr := <-interruptRes; ctxErr != nil { - // Return context error to user. - ret = ctxErr - } - }() - go func() { - select { - case <-handshakeCtx.Done(): - // Close the connection, discarding the error - _ = c.conn.Close() - interruptRes <- handshakeCtx.Err() - case <-done: - interruptRes <- nil - } - }() - } - - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - if err := c.handshakeErr; err != nil { - return err - } - if c.isHandshakeComplete.Load() { - return nil - } - - c.in.Lock() - defer c.in.Unlock() - - c.handshakeErr = c.handshakeFn(handshakeCtx) - if c.handshakeErr == nil { - c.handshakes++ - } else { - // If an error occurred during the handshake try to flush the - // alert that might be left in the buffer. - c.flush() - } - - if c.handshakeErr == nil && !c.isHandshakeComplete.Load() { - c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") - } - if c.handshakeErr != nil && c.isHandshakeComplete.Load() { - panic("tls: internal error: handshake returned an error but is marked successful") - } - - if c.quic != nil { - if c.handshakeErr == nil { - c.quicHandshakeComplete() - // Provide the 1-RTT read secret now that the handshake is complete. - // The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing - // the handshake (RFC 9001, Section 5.7). - c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret) - } else { - var a alert - c.out.Lock() - if !errors.As(c.out.err, &a) { - a = alertInternalError - } - c.out.Unlock() - // Return an error which wraps both the handshake error and - // any alert error we may have sent, or alertInternalError - // if we didn't send an alert. - // Truncate the text of the alert to 0 characters. - c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a)) - } - close(c.quic.blockedc) - close(c.quic.signalc) - } - - return c.handshakeErr -} - -// ConnectionState returns basic TLS details about the connection. -func (c *Conn) ConnectionState() ConnectionState { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - return c.connectionStateLocked() -} - -func (c *Conn) connectionStateLocked() ConnectionState { - var state connectionState - state.HandshakeComplete = c.isHandshakeComplete.Load() - state.Version = c.vers - state.NegotiatedProtocol = c.clientProtocol - state.DidResume = c.didResume - state.NegotiatedProtocolIsMutual = true - state.ServerName = c.serverName - state.CipherSuite = c.cipherSuite - state.PeerCertificates = c.peerCertificates - state.VerifiedChains = c.verifiedChains - state.SignedCertificateTimestamps = c.scts - state.OCSPResponse = c.ocspResponse - if !c.didResume && c.vers != VersionTLS13 { - if c.clientFinishedIsFirst { - state.TLSUnique = c.clientFinished[:] - } else { - state.TLSUnique = c.serverFinished[:] - } - } - if c.config.Renegotiation != RenegotiateNever { - state.ekm = noExportedKeyingMaterial - } else { - state.ekm = c.ekm - } - return toConnectionState(state) -} - -// OCSPResponse returns the stapled OCSP response from the TLS server, if -// any. (Only valid for client connections.) -func (c *Conn) OCSPResponse() []byte { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - return c.ocspResponse -} - -// VerifyHostname checks that the peer certificate chain is valid for -// connecting to host. If so, it returns nil; if not, it returns an error -// describing the problem. -func (c *Conn) VerifyHostname(host string) error { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - if !c.isClient { - return errors.New("tls: VerifyHostname called on TLS server connection") - } - if !c.isHandshakeComplete.Load() { - return errors.New("tls: handshake has not yet been performed") - } - if len(c.verifiedChains) == 0 { - return errors.New("tls: handshake did not verify certificate chain") - } - return c.peerCertificates[0].VerifyHostname(host) -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/handshake_client.go b/vendor/github.com/quic-go/qtls-go1-20/handshake_client.go deleted file mode 100644 index a5fdd54ca..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/handshake_client.go +++ /dev/null @@ -1,1130 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "context" - "crypto" - "crypto/ecdh" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "net" - "strings" - "time" - - "golang.org/x/crypto/cryptobyte" -) - -const clientSessionStateVersion = 1 - -type clientHandshakeState struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - suite *cipherSuite - finishedHash finishedHash - masterSecret []byte - session *clientSessionState -} - -var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme - -func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) { - config := c.config - if len(config.ServerName) == 0 && !config.InsecureSkipVerify { - return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") - } - - nextProtosLength := 0 - for _, proto := range config.NextProtos { - if l := len(proto); l == 0 || l > 255 { - return nil, nil, errors.New("tls: invalid NextProtos value") - } else { - nextProtosLength += 1 + l - } - } - if nextProtosLength > 0xffff { - return nil, nil, errors.New("tls: NextProtos values too large") - } - - supportedVersions := config.supportedVersions(roleClient) - if len(supportedVersions) == 0 { - return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") - } - - clientHelloVersion := config.maxSupportedVersion(roleClient) - // The version at the beginning of the ClientHello was capped at TLS 1.2 - // for compatibility reasons. The supported_versions extension is used - // to negotiate versions now. See RFC 8446, Section 4.2.1. - if clientHelloVersion > VersionTLS12 { - clientHelloVersion = VersionTLS12 - } - - hello := &clientHelloMsg{ - vers: clientHelloVersion, - compressionMethods: []uint8{compressionNone}, - random: make([]byte, 32), - ocspStapling: true, - scts: true, - serverName: hostnameInSNI(config.ServerName), - supportedCurves: config.curvePreferences(), - supportedPoints: []uint8{pointFormatUncompressed}, - secureRenegotiationSupported: true, - alpnProtocols: config.NextProtos, - supportedVersions: supportedVersions, - } - - if c.handshakes > 0 { - hello.secureRenegotiation = c.clientFinished[:] - } - - preferenceOrder := cipherSuitesPreferenceOrder - if !hasAESGCMHardwareSupport { - preferenceOrder = cipherSuitesPreferenceOrderNoAES - } - configCipherSuites := config.cipherSuites() - hello.cipherSuites = make([]uint16, 0, len(configCipherSuites)) - - for _, suiteId := range preferenceOrder { - suite := mutualCipherSuite(configCipherSuites, suiteId) - if suite == nil { - continue - } - // Don't advertise TLS 1.2-only cipher suites unless - // we're attempting TLS 1.2. - if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { - continue - } - hello.cipherSuites = append(hello.cipherSuites, suiteId) - } - - _, err := io.ReadFull(config.rand(), hello.random) - if err != nil { - return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) - } - - // A random session ID is used to detect when the server accepted a ticket - // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as - // a compatibility measure (see RFC 8446, Section 4.1.2). - // - // The session ID is not set for QUIC connections (see RFC 9001, Section 8.4). - if c.quic == nil { - hello.sessionId = make([]byte, 32) - if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { - return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) - } - } - - if hello.vers >= VersionTLS12 { - hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms() - } - if testingOnlyForceClientHelloSignatureAlgorithms != nil { - hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms - } - - var key *ecdh.PrivateKey - if hello.supportedVersions[0] == VersionTLS13 { - if len(hello.supportedVersions) == 1 { - hello.cipherSuites = hello.cipherSuites[:0] - } - if hasAESGCMHardwareSupport { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) - } else { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) - } - - curveID := config.curvePreferences()[0] - if _, ok := curveForCurveID(curveID); !ok { - return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") - } - key, err = generateECDHEKey(config.rand(), curveID) - if err != nil { - return nil, nil, err - } - hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} - } - - if c.quic != nil { - p, err := c.quicGetTransportParameters() - if err != nil { - return nil, nil, err - } - if p == nil { - p = []byte{} - } - hello.quicTransportParameters = p - } - - return hello, key, nil -} - -func (c *Conn) clientHandshake(ctx context.Context) (err error) { - if c.config == nil { - c.config = fromConfig(defaultConfig()) - } - - // This may be a renegotiation handshake, in which case some fields - // need to be reset. - c.didResume = false - - hello, ecdheKey, err := c.makeClientHello() - if err != nil { - return err - } - c.serverName = hello.serverName - - cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello) - if err != nil { - return err - } - if cacheKey != "" && session != nil { - defer func() { - // If we got a handshake failure when resuming a session, throw away - // the session ticket. See RFC 5077, Section 3.2. - // - // RFC 8446 makes no mention of dropping tickets on failure, but it - // does require servers to abort on invalid binders, so we need to - // delete tickets to recover from a corrupted PSK. - if err != nil { - c.config.ClientSessionCache.Put(cacheKey, nil) - } - }() - } - - if _, err := c.writeHandshakeRecord(hello, nil); err != nil { - return err - } - - if hello.earlyData { - suite := cipherSuiteTLS13ByID(session.cipherSuite) - transcript := suite.hash.New() - if err := transcriptMsg(hello, transcript); err != nil { - return err - } - earlyTrafficSecret := suite.deriveSecret(earlySecret, clientEarlyTrafficLabel, transcript) - c.quicSetWriteSecret(QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret) - } - - // serverHelloMsg is not included in the transcript - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - serverHello, ok := msg.(*serverHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - - if err := c.pickTLSVersion(serverHello); err != nil { - return err - } - - // If we are negotiating a protocol version that's lower than what we - // support, check for the server downgrade canaries. - // See RFC 8446, Section 4.1.3. - maxVers := c.config.maxSupportedVersion(roleClient) - tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 - tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 - if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || - maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") - } - - if c.vers == VersionTLS13 { - hs := &clientHandshakeStateTLS13{ - c: c, - ctx: ctx, - serverHello: serverHello, - hello: hello, - ecdheKey: ecdheKey, - session: session, - earlySecret: earlySecret, - binderKey: binderKey, - } - - // In TLS 1.3, session tickets are delivered after the handshake. - return hs.handshake() - } - - hs := &clientHandshakeState{ - c: c, - ctx: ctx, - serverHello: serverHello, - hello: hello, - session: session, - } - - if err := hs.handshake(); err != nil { - return err - } - - // If we had a successful handshake and hs.session is different from - // the one already cached - cache a new one. - if cacheKey != "" && hs.session != nil && session != hs.session { - c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(hs.session)) - } - - return nil -} - -// extract the app data saved in the session.nonce, -// and set the session.nonce to the actual nonce value -func (c *Conn) decodeSessionState(session *clientSessionState) (uint32 /* max early data */, []byte /* app data */, bool /* ok */) { - s := cryptobyte.String(session.nonce) - var version uint16 - if !s.ReadUint16(&version) { - return 0, nil, false - } - if version != clientSessionStateVersion { - return 0, nil, false - } - var maxEarlyData uint32 - if !s.ReadUint32(&maxEarlyData) { - return 0, nil, false - } - var appData []byte - if !readUint16LengthPrefixed(&s, &appData) { - return 0, nil, false - } - var nonce []byte - if !readUint16LengthPrefixed(&s, &nonce) { - return 0, nil, false - } - session.nonce = nonce - return maxEarlyData, appData, true -} - -func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, - session *clientSessionState, earlySecret, binderKey []byte, err error) { - if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return "", nil, nil, nil, nil - } - - hello.ticketSupported = true - - if hello.supportedVersions[0] == VersionTLS13 { - // Require DHE on resumption as it guarantees forward secrecy against - // compromise of the session ticket key. See RFC 8446, Section 4.2.9. - hello.pskModes = []uint8{pskModeDHE} - } - - // Session resumption is not allowed if renegotiating because - // renegotiation is primarily used to allow a client to send a client - // certificate, which would be skipped if session resumption occurred. - if c.handshakes != 0 { - return "", nil, nil, nil, nil - } - - // Try to resume a previously negotiated TLS session, if available. - cacheKey = c.clientSessionCacheKey() - if cacheKey == "" { - return "", nil, nil, nil, nil - } - sess, ok := c.config.ClientSessionCache.Get(cacheKey) - if !ok || sess == nil { - return cacheKey, nil, nil, nil, nil - } - session = fromClientSessionState(sess) - - var appData []byte - var maxEarlyData uint32 - if session.vers == VersionTLS13 { - var ok bool - maxEarlyData, appData, ok = c.decodeSessionState(session) - if !ok { // delete it, if parsing failed - c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil, nil - } - } - - // Check that version used for the previous session is still valid. - versOk := false - for _, v := range hello.supportedVersions { - if v == session.vers { - versOk = true - break - } - } - if !versOk { - return cacheKey, nil, nil, nil, nil - } - - // Check that the cached server certificate is not expired, and that it's - // valid for the ServerName. This should be ensured by the cache key, but - // protect the application from a faulty ClientSessionCache implementation. - if !c.config.InsecureSkipVerify { - if len(session.verifiedChains) == 0 { - // The original connection had InsecureSkipVerify, while this doesn't. - return cacheKey, nil, nil, nil, nil - } - serverCert := session.serverCertificates[0] - if c.config.time().After(serverCert.NotAfter) { - // Expired certificate, delete the entry. - c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil, nil - } - if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { - return cacheKey, nil, nil, nil, nil - } - } - - if session.vers != VersionTLS13 { - // In TLS 1.2 the cipher suite must match the resumed session. Ensure we - // are still offering it. - if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { - return cacheKey, nil, nil, nil, nil - } - - hello.sessionTicket = session.sessionTicket - return - } - - // Check that the session ticket is not expired. - if c.config.time().After(session.useBy) { - c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil, nil - } - - // In TLS 1.3 the KDF hash must match the resumed session. Ensure we - // offer at least one cipher suite with that hash. - cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) - if cipherSuite == nil { - return cacheKey, nil, nil, nil, nil - } - cipherSuiteOk := false - for _, offeredID := range hello.cipherSuites { - offeredSuite := cipherSuiteTLS13ByID(offeredID) - if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash { - cipherSuiteOk = true - break - } - } - if !cipherSuiteOk { - return cacheKey, nil, nil, nil, nil - } - - if c.quic != nil && maxEarlyData > 0 { - // For 0-RTT, the cipher suite has to match exactly. - if mutualCipherSuiteTLS13(hello.cipherSuites, session.cipherSuite) != nil { - hello.earlyData = true - } - } - - // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. - ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond) - identity := pskIdentity{ - label: session.sessionTicket, - obfuscatedTicketAge: ticketAge + session.ageAdd, - } - hello.pskIdentities = []pskIdentity{identity} - hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} - - // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. - psk := cipherSuite.expandLabel(session.masterSecret, "resumption", - session.nonce, cipherSuite.hash.Size()) - earlySecret = cipherSuite.extract(psk, nil) - binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) - transcript := cipherSuite.hash.New() - helloBytes, err := hello.marshalWithoutBinders() - if err != nil { - return "", nil, nil, nil, err - } - transcript.Write(helloBytes) - pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} - if err := hello.updateBinders(pskBinders); err != nil { - return "", nil, nil, nil, err - } - - if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil { - c.extraConfig.SetAppDataFromSessionState(appData) - } - return -} - -func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { - peerVersion := serverHello.vers - if serverHello.supportedVersion != 0 { - peerVersion = serverHello.supportedVersion - } - - vers, ok := c.config.mutualVersion(roleClient, []uint16{peerVersion}) - if !ok { - c.sendAlert(alertProtocolVersion) - return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) - } - - c.vers = vers - c.haveVers = true - c.in.version = vers - c.out.version = vers - - return nil -} - -// Does the handshake, either a full one or resumes old session. Requires hs.c, -// hs.hello, hs.serverHello, and, optionally, hs.session to be set. -func (hs *clientHandshakeState) handshake() error { - c := hs.c - - isResume, err := hs.processServerHello() - if err != nil { - return err - } - - hs.finishedHash = newFinishedHash(c.vers, hs.suite) - - // No signatures of the handshake are needed in a resumption. - // Otherwise, in a full handshake, if we don't have any certificates - // configured then we will never send a CertificateVerify message and - // thus no signatures are needed in that case either. - if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { - hs.finishedHash.discardHandshakeBuffer() - } - - if err := transcriptMsg(hs.hello, &hs.finishedHash); err != nil { - return err - } - if err := transcriptMsg(hs.serverHello, &hs.finishedHash); err != nil { - return err - } - - c.buffering = true - c.didResume = isResume - if isResume { - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.readSessionTicket(); err != nil { - return err - } - if err := hs.readFinished(c.serverFinished[:]); err != nil { - return err - } - c.clientFinishedIsFirst = false - // Make sure the connection is still being verified whether or not this - // is a resumption. Resumptions currently don't reverify certificates so - // they don't call verifyServerCertificate. See Issue 31641. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - if err := hs.sendFinished(c.clientFinished[:]); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - } else { - if err := hs.doFullHandshake(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.sendFinished(c.clientFinished[:]); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - c.clientFinishedIsFirst = true - if err := hs.readSessionTicket(); err != nil { - return err - } - if err := hs.readFinished(c.serverFinished[:]); err != nil { - return err - } - } - - c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) - c.isHandshakeComplete.Store(true) - - return nil -} - -func (hs *clientHandshakeState) pickCipherSuite() error { - if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { - hs.c.sendAlert(alertHandshakeFailure) - return errors.New("tls: server chose an unconfigured cipher suite") - } - - hs.c.cipherSuite = hs.suite.id - return nil -} - -func (hs *clientHandshakeState) doFullHandshake() error { - c := hs.c - - msg, err := c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - certMsg, ok := msg.(*certificateMsg) - if !ok || len(certMsg.certificates) == 0 { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - - msg, err = c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - - cs, ok := msg.(*certificateStatusMsg) - if ok { - // RFC4366 on Certificate Status Request: - // The server MAY return a "certificate_status" message. - - if !hs.serverHello.ocspStapling { - // If a server returns a "CertificateStatus" message, then the - // server MUST have included an extension of type "status_request" - // with empty "extension_data" in the extended server hello. - - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: received unexpected CertificateStatus message") - } - - c.ocspResponse = cs.response - - msg, err = c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - } - - if c.handshakes == 0 { - // If this is the first handshake on a connection, process and - // (optionally) verify the server's certificates. - if err := c.verifyServerCertificate(certMsg.certificates); err != nil { - return err - } - } else { - // This is a renegotiation handshake. We require that the - // server's identity (i.e. leaf certificate) is unchanged and - // thus any previous trust decision is still valid. - // - // See https://mitls.org/pages/attacks/3SHAKE for the - // motivation behind this requirement. - if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { - c.sendAlert(alertBadCertificate) - return errors.New("tls: server's identity changed during renegotiation") - } - } - - keyAgreement := hs.suite.ka(c.vers) - - skx, ok := msg.(*serverKeyExchangeMsg) - if ok { - err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) - if err != nil { - c.sendAlert(alertUnexpectedMessage) - return err - } - - msg, err = c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - } - - var chainToSend *Certificate - var certRequested bool - certReq, ok := msg.(*certificateRequestMsg) - if ok { - certRequested = true - - cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) - if chainToSend, err = c.getClientCertificate(cri); err != nil { - c.sendAlert(alertInternalError) - return err - } - - msg, err = c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - } - - shd, ok := msg.(*serverHelloDoneMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(shd, msg) - } - - // If the server requested a certificate then we have to send a - // Certificate message, even if it's empty because we don't have a - // certificate to send. - if certRequested { - certMsg = new(certificateMsg) - certMsg.certificates = chainToSend.Certificate - if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { - return err - } - } - - preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - if ckx != nil { - if _, err := hs.c.writeHandshakeRecord(ckx, &hs.finishedHash); err != nil { - return err - } - } - - if chainToSend != nil && len(chainToSend.Certificate) > 0 { - certVerify := &certificateVerifyMsg{} - - key, ok := chainToSend.PrivateKey.(crypto.Signer) - if !ok { - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) - } - - var sigType uint8 - var sigHash crypto.Hash - if c.vers >= VersionTLS12 { - signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - certVerify.hasSignatureAlgorithm = true - certVerify.signatureAlgorithm = signatureAlgorithm - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - } - - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil { - return err - } - } - - hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) - if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil { - c.sendAlert(alertInternalError) - return errors.New("tls: failed to write to key log: " + err.Error()) - } - - hs.finishedHash.discardHandshakeBuffer() - - return nil -} - -func (hs *clientHandshakeState) establishKeys() error { - c := hs.c - - clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := - keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) - var clientCipher, serverCipher any - var clientHash, serverHash hash.Hash - if hs.suite.cipher != nil { - clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) - clientHash = hs.suite.mac(clientMAC) - serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) - serverHash = hs.suite.mac(serverMAC) - } else { - clientCipher = hs.suite.aead(clientKey, clientIV) - serverCipher = hs.suite.aead(serverKey, serverIV) - } - - c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) - c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) - return nil -} - -func (hs *clientHandshakeState) serverResumedSession() bool { - // If the server responded with the same sessionId then it means the - // sessionTicket is being used to resume a TLS session. - return hs.session != nil && hs.hello.sessionId != nil && - bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) -} - -func (hs *clientHandshakeState) processServerHello() (bool, error) { - c := hs.c - - if err := hs.pickCipherSuite(); err != nil { - return false, err - } - - if hs.serverHello.compressionMethod != compressionNone { - c.sendAlert(alertUnexpectedMessage) - return false, errors.New("tls: server selected unsupported compression format") - } - - if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported { - c.secureRenegotiation = true - if len(hs.serverHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: initial handshake had non-empty renegotiation extension") - } - } - - if c.handshakes > 0 && c.secureRenegotiation { - var expectedSecureRenegotiation [24]byte - copy(expectedSecureRenegotiation[:], c.clientFinished[:]) - copy(expectedSecureRenegotiation[12:], c.serverFinished[:]) - if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: incorrect renegotiation extension contents") - } - } - - if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol, false); err != nil { - c.sendAlert(alertUnsupportedExtension) - return false, err - } - c.clientProtocol = hs.serverHello.alpnProtocol - - c.scts = hs.serverHello.scts - - if !hs.serverResumedSession() { - return false, nil - } - - if hs.session.vers != c.vers { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: server resumed a session with a different version") - } - - if hs.session.cipherSuite != hs.suite.id { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: server resumed a session with a different cipher suite") - } - - // Restore masterSecret, peerCerts, and ocspResponse from previous state - hs.masterSecret = hs.session.masterSecret - c.peerCertificates = hs.session.serverCertificates - c.verifiedChains = hs.session.verifiedChains - c.ocspResponse = hs.session.ocspResponse - // Let the ServerHello SCTs override the session SCTs from the original - // connection, if any are provided - if len(c.scts) == 0 && len(hs.session.scts) != 0 { - c.scts = hs.session.scts - } - - return true, nil -} - -// checkALPN ensure that the server's choice of ALPN protocol is compatible with -// the protocols that we advertised in the Client Hello. -func checkALPN(clientProtos []string, serverProto string, quic bool) error { - if serverProto == "" { - if quic && len(clientProtos) > 0 { - // RFC 9001, Section 8.1 - return errors.New("tls: server did not select an ALPN protocol") - } - return nil - } - if len(clientProtos) == 0 { - return errors.New("tls: server advertised unrequested ALPN extension") - } - for _, proto := range clientProtos { - if proto == serverProto { - return nil - } - } - return errors.New("tls: server selected unadvertised ALPN protocol") -} - -func (hs *clientHandshakeState) readFinished(out []byte) error { - c := hs.c - - if err := c.readChangeCipherSpec(); err != nil { - return err - } - - // finishedMsg is included in the transcript, but not until after we - // check the client version, since the state before this message was - // sent is used during verification. - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - serverFinished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverFinished, msg) - } - - verify := hs.finishedHash.serverSum(hs.masterSecret) - if len(verify) != len(serverFinished.verifyData) || - subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: server's Finished message was incorrect") - } - - if err := transcriptMsg(serverFinished, &hs.finishedHash); err != nil { - return err - } - - copy(out, verify) - return nil -} - -func (hs *clientHandshakeState) readSessionTicket() error { - if !hs.serverHello.ticketSupported { - return nil - } - - c := hs.c - msg, err := c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - sessionTicketMsg, ok := msg.(*newSessionTicketMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(sessionTicketMsg, msg) - } - - hs.session = &clientSessionState{ - sessionTicket: sessionTicketMsg.ticket, - vers: c.vers, - cipherSuite: hs.suite.id, - masterSecret: hs.masterSecret, - serverCertificates: c.peerCertificates, - verifiedChains: c.verifiedChains, - receivedAt: c.config.time(), - ocspResponse: c.ocspResponse, - scts: c.scts, - } - - return nil -} - -func (hs *clientHandshakeState) sendFinished(out []byte) error { - c := hs.c - - if err := c.writeChangeCipherRecord(); err != nil { - return err - } - - finished := new(finishedMsg) - finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) - if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { - return err - } - copy(out, finished.verifyData) - return nil -} - -// maxRSAKeySize is the maximum RSA key size in bits that we are willing -// to verify the signatures of during a TLS handshake. -const maxRSAKeySize = 8192 - -// verifyServerCertificate parses and verifies the provided chain, setting -// c.verifiedChains and c.peerCertificates or sending the appropriate alert. -func (c *Conn) verifyServerCertificate(certificates [][]byte) error { - activeHandles := make([]*activeCert, len(certificates)) - certs := make([]*x509.Certificate, len(certificates)) - for i, asn1Data := range certificates { - cert, err := clientCertCache.newCert(asn1Data) - if err != nil { - c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to parse certificate from server: " + err.Error()) - } - if cert.cert.PublicKeyAlgorithm == x509.RSA && cert.cert.PublicKey.(*rsa.PublicKey).N.BitLen() > maxRSAKeySize { - c.sendAlert(alertBadCertificate) - return fmt.Errorf("tls: server sent certificate containing RSA key larger than %d bits", maxRSAKeySize) - } - activeHandles[i] = cert - certs[i] = cert.cert - } - - if !c.config.InsecureSkipVerify { - opts := x509.VerifyOptions{ - Roots: c.config.RootCAs, - CurrentTime: c.config.time(), - DNSName: c.config.ServerName, - Intermediates: x509.NewCertPool(), - } - - for _, cert := range certs[1:] { - opts.Intermediates.AddCert(cert) - } - var err error - c.verifiedChains, err = certs[0].Verify(opts) - if err != nil { - c.sendAlert(alertBadCertificate) - return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err} - } - } - - switch certs[0].PublicKey.(type) { - case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: - break - default: - c.sendAlert(alertUnsupportedCertificate) - return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) - } - - c.activeCertHandles = activeHandles - c.peerCertificates = certs - - if c.config.VerifyPeerCertificate != nil { - if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - return nil -} - -// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS -// <= 1.2 CertificateRequest, making an effort to fill in missing information. -func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { - cri := &certificateRequestInfo{ - AcceptableCAs: certReq.certificateAuthorities, - Version: vers, - ctx: ctx, - } - - var rsaAvail, ecAvail bool - for _, certType := range certReq.certificateTypes { - switch certType { - case certTypeRSASign: - rsaAvail = true - case certTypeECDSASign: - ecAvail = true - } - } - - if !certReq.hasSignatureAlgorithm { - // Prior to TLS 1.2, signature schemes did not exist. In this case we - // make up a list based on the acceptable certificate types, to help - // GetClientCertificate and SupportsCertificate select the right certificate. - // The hash part of the SignatureScheme is a lie here, because - // TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA. - switch { - case rsaAvail && ecAvail: - cri.SignatureSchemes = []SignatureScheme{ - ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, - PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, - } - case rsaAvail: - cri.SignatureSchemes = []SignatureScheme{ - PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, - } - case ecAvail: - cri.SignatureSchemes = []SignatureScheme{ - ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, - } - } - return toCertificateRequestInfo(cri) - } - - // Filter the signature schemes based on the certificate types. - // See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated"). - cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms)) - for _, sigScheme := range certReq.supportedSignatureAlgorithms { - sigType, _, err := typeAndHashFromSignatureScheme(sigScheme) - if err != nil { - continue - } - switch sigType { - case signatureECDSA, signatureEd25519: - if ecAvail { - cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) - } - case signatureRSAPSS, signaturePKCS1v15: - if rsaAvail { - cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) - } - } - } - - return toCertificateRequestInfo(cri) -} - -func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) { - if c.config.GetClientCertificate != nil { - return c.config.GetClientCertificate(cri) - } - - for _, chain := range c.config.Certificates { - if err := cri.SupportsCertificate(&chain); err != nil { - continue - } - return &chain, nil - } - - // No acceptable certificate found. Don't send a certificate. - return new(Certificate), nil -} - -// clientSessionCacheKey returns a key used to cache sessionTickets that could -// be used to resume previously negotiated TLS sessions with a server. -func (c *Conn) clientSessionCacheKey() string { - if len(c.config.ServerName) > 0 { - return c.config.ServerName - } - if c.conn != nil { - return c.conn.RemoteAddr().String() - } - return "" -} - -// hostnameInSNI converts name into an appropriate hostname for SNI. -// Literal IP addresses and absolute FQDNs are not permitted as SNI values. -// See RFC 6066, Section 3. -func hostnameInSNI(name string) string { - host := name - if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { - host = host[1 : len(host)-1] - } - if i := strings.LastIndex(host, "%"); i > 0 { - host = host[:i] - } - if net.ParseIP(host) != nil { - return "" - } - for len(name) > 0 && name[len(name)-1] == '.' { - name = name[:len(name)-1] - } - return name -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/handshake_client_tls13.go b/vendor/github.com/quic-go/qtls-go1-20/handshake_client_tls13.go deleted file mode 100644 index e9d0a533e..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/handshake_client_tls13.go +++ /dev/null @@ -1,782 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "context" - "crypto" - "crypto/ecdh" - "crypto/hmac" - "crypto/rsa" - "encoding/binary" - "errors" - "hash" - "time" - - "golang.org/x/crypto/cryptobyte" -) - -type clientHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - ecdheKey *ecdh.PrivateKey - - session *clientSessionState - earlySecret []byte - binderKey []byte - - certReq *certificateRequestMsgTLS13 - usingPSK bool - sentDummyCCS bool - suite *cipherSuiteTLS13 - transcript hash.Hash - masterSecret []byte - trafficSecret []byte // client_application_traffic_secret_0 -} - -// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheKey, and, -// optionally, hs.session, hs.earlySecret and hs.binderKey to be set. -func (hs *clientHandshakeStateTLS13) handshake() error { - c := hs.c - - if needFIPS() { - return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode") - } - - // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, - // sections 4.1.2 and 4.1.3. - if c.handshakes > 0 { - c.sendAlert(alertProtocolVersion) - return errors.New("tls: server selected TLS 1.3 in a renegotiation") - } - - // Consistency check on the presence of a keyShare and its parameters. - if hs.ecdheKey == nil || len(hs.hello.keyShares) != 1 { - return c.sendAlert(alertInternalError) - } - - if err := hs.checkServerHelloOrHRR(); err != nil { - return err - } - - hs.transcript = hs.suite.hash.New() - - if err := transcriptMsg(hs.hello, hs.transcript); err != nil { - return err - } - - if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - if err := hs.processHelloRetryRequest(); err != nil { - return err - } - } - - if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { - return err - } - - c.buffering = true - if err := hs.processServerHello(); err != nil { - return err - } - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - if err := hs.establishHandshakeKeys(); err != nil { - return err - } - if err := hs.readServerParameters(); err != nil { - return err - } - if err := hs.readServerCertificate(); err != nil { - return err - } - if err := hs.readServerFinished(); err != nil { - return err - } - if err := hs.sendClientCertificate(); err != nil { - return err - } - if err := hs.sendClientFinished(); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - - c.isHandshakeComplete.Store(true) - - return nil -} - -// checkServerHelloOrHRR does validity checks that apply to both ServerHello and -// HelloRetryRequest messages. It sets hs.suite. -func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error { - c := hs.c - - if hs.serverHello.supportedVersion == 0 { - c.sendAlert(alertMissingExtension) - return errors.New("tls: server selected TLS 1.3 using the legacy version field") - } - - if hs.serverHello.supportedVersion != VersionTLS13 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid version after a HelloRetryRequest") - } - - if hs.serverHello.vers != VersionTLS12 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an incorrect legacy version") - } - - if hs.serverHello.ocspStapling || - hs.serverHello.ticketSupported || - hs.serverHello.secureRenegotiationSupported || - len(hs.serverHello.secureRenegotiation) != 0 || - len(hs.serverHello.alpnProtocol) != 0 || - len(hs.serverHello.scts) != 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3") - } - - if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server did not echo the legacy session ID") - } - - if hs.serverHello.compressionMethod != compressionNone { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported compression format") - } - - selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite) - if hs.suite != nil && selectedSuite != hs.suite { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server changed cipher suite after a HelloRetryRequest") - } - if selectedSuite == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server chose an unconfigured cipher suite") - } - hs.suite = selectedSuite - c.cipherSuite = hs.suite.id - - return nil -} - -// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility -// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. -func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { - if hs.c.quic != nil { - return nil - } - if hs.sentDummyCCS { - return nil - } - hs.sentDummyCCS = true - - return hs.c.writeChangeCipherRecord() -} - -// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and -// resends hs.hello, and reads the new ServerHello into hs.serverHello. -func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { - c := hs.c - - // The first ClientHello gets double-hashed into the transcript upon a - // HelloRetryRequest. (The idea is that the server might offload transcript - // storage to the client in the cookie.) See RFC 8446, Section 4.4.1. - chHash := hs.transcript.Sum(nil) - hs.transcript.Reset() - hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - hs.transcript.Write(chHash) - if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { - return err - } - - // The only HelloRetryRequest extensions we support are key_share and - // cookie, and clients must abort the handshake if the HRR would not result - // in any change in the ClientHello. - if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an unnecessary HelloRetryRequest message") - } - - if hs.serverHello.cookie != nil { - hs.hello.cookie = hs.serverHello.cookie - } - - if hs.serverHello.serverShare.group != 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: received malformed key_share extension") - } - - // If the server sent a key_share extension selecting a group, ensure it's - // a group we advertised but did not send a key share for, and send a key - // share for it this time. - if curveID := hs.serverHello.selectedGroup; curveID != 0 { - curveOK := false - for _, id := range hs.hello.supportedCurves { - if id == curveID { - curveOK = true - break - } - } - if !curveOK { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported group") - } - if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); sentID == curveID { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") - } - if _, ok := curveForCurveID(curveID); !ok { - c.sendAlert(alertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - key, err := generateECDHEKey(c.config.rand(), curveID) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - hs.ecdheKey = key - hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} - } - - hs.hello.raw = nil - if len(hs.hello.pskIdentities) > 0 { - pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) - if pskSuite == nil { - return c.sendAlert(alertInternalError) - } - if pskSuite.hash == hs.suite.hash { - // Update binders and obfuscated_ticket_age. - ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond) - hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd - - transcript := hs.suite.hash.New() - transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - transcript.Write(chHash) - if err := transcriptMsg(hs.serverHello, transcript); err != nil { - return err - } - helloBytes, err := hs.hello.marshalWithoutBinders() - if err != nil { - return err - } - transcript.Write(helloBytes) - pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} - if err := hs.hello.updateBinders(pskBinders); err != nil { - return err - } - } else { - // Server selected a cipher suite incompatible with the PSK. - hs.hello.pskIdentities = nil - hs.hello.pskBinders = nil - } - } - - if hs.hello.earlyData { - hs.hello.earlyData = false - c.quicRejectedEarlyData() - } - - if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { - return err - } - - // serverHelloMsg is not included in the transcript - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - serverHello, ok := msg.(*serverHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - hs.serverHello = serverHello - - if err := hs.checkServerHelloOrHRR(); err != nil { - return err - } - - return nil -} - -func (hs *clientHandshakeStateTLS13) processServerHello() error { - c := hs.c - - if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: server sent two HelloRetryRequest messages") - } - - if len(hs.serverHello.cookie) != 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server sent a cookie in a normal ServerHello") - } - - if hs.serverHello.selectedGroup != 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: malformed key_share extension") - } - - if hs.serverHello.serverShare.group == 0 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server did not send a key share") - } - if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported group") - } - - if !hs.serverHello.selectedIdentityPresent { - return nil - } - - if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid PSK") - } - - if len(hs.hello.pskIdentities) != 1 || hs.session == nil { - return c.sendAlert(alertInternalError) - } - pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) - if pskSuite == nil { - return c.sendAlert(alertInternalError) - } - if pskSuite.hash != hs.suite.hash { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid PSK and cipher suite pair") - } - - hs.usingPSK = true - c.didResume = true - c.peerCertificates = hs.session.serverCertificates - c.verifiedChains = hs.session.verifiedChains - c.ocspResponse = hs.session.ocspResponse - c.scts = hs.session.scts - return nil -} - -func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { - c := hs.c - - peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") - } - sharedKey, err := hs.ecdheKey.ECDH(peerKey) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") - } - - earlySecret := hs.earlySecret - if !hs.usingPSK { - earlySecret = hs.suite.extract(nil, nil) - } - - handshakeSecret := hs.suite.extract(sharedKey, - hs.suite.deriveSecret(earlySecret, "derived", nil)) - - clientSecret := hs.suite.deriveSecret(handshakeSecret, - clientHandshakeTrafficLabel, hs.transcript) - c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret) - serverSecret := hs.suite.deriveSecret(handshakeSecret, - serverHandshakeTrafficLabel, hs.transcript) - c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret) - - if c.quic != nil { - if c.hand.Len() != 0 { - c.sendAlert(alertUnexpectedMessage) - } - c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret) - c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret) - } - - err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - hs.masterSecret = hs.suite.extract(nil, - hs.suite.deriveSecret(handshakeSecret, "derived", nil)) - - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerParameters() error { - c := hs.c - - msg, err := c.readHandshake(hs.transcript) - if err != nil { - return err - } - - encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(encryptedExtensions, msg) - } - - if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol, c.quic != nil); err != nil { - // RFC 8446 specifies that no_application_protocol is sent by servers, but - // does not specify how clients handle the selection of an incompatible protocol. - // RFC 9001 Section 8.1 specifies that QUIC clients send no_application_protocol - // in this case. Always sending no_application_protocol seems reasonable. - c.sendAlert(alertNoApplicationProtocol) - return err - } - c.clientProtocol = encryptedExtensions.alpnProtocol - - if c.quic != nil { - if encryptedExtensions.quicTransportParameters == nil { - // RFC 9001 Section 8.2. - c.sendAlert(alertMissingExtension) - return errors.New("tls: server did not send a quic_transport_parameters extension") - } - c.quicSetTransportParameters(encryptedExtensions.quicTransportParameters) - } else { - if encryptedExtensions.quicTransportParameters != nil { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server sent an unexpected quic_transport_parameters extension") - } - } - - if hs.hello.earlyData && !encryptedExtensions.earlyData { - c.quicRejectedEarlyData() - } - - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerCertificate() error { - c := hs.c - - // Either a PSK or a certificate is always used, but not both. - // See RFC 8446, Section 4.1.1. - if hs.usingPSK { - // Make sure the connection is still being verified whether or not this - // is a resumption. Resumptions currently don't reverify certificates so - // they don't call verifyServerCertificate. See Issue 31641. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - return nil - } - - msg, err := c.readHandshake(hs.transcript) - if err != nil { - return err - } - - certReq, ok := msg.(*certificateRequestMsgTLS13) - if ok { - hs.certReq = certReq - - msg, err = c.readHandshake(hs.transcript) - if err != nil { - return err - } - } - - certMsg, ok := msg.(*certificateMsgTLS13) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - if len(certMsg.certificate.Certificate) == 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: received empty certificates message") - } - - c.scts = certMsg.certificate.SignedCertificateTimestamps - c.ocspResponse = certMsg.certificate.OCSPStaple - - if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil { - return err - } - - // certificateVerifyMsg is included in the transcript, but not until - // after we verify the handshake signature, since the state before - // this message was sent is used. - msg, err = c.readHandshake(nil) - if err != nil { - return err - } - - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: certificate used with invalid signature algorithm") - } - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: certificate used with invalid signature algorithm") - } - signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) - if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, - sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the server certificate: " + err.Error()) - } - - if err := transcriptMsg(certVerify, hs.transcript); err != nil { - return err - } - - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerFinished() error { - c := hs.c - - // finishedMsg is included in the transcript, but not until after we - // check the client version, since the state before this message was - // sent is used during verification. - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - finished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(finished, msg) - } - - expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) - if !hmac.Equal(expectedMAC, finished.verifyData) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid server finished hash") - } - - if err := transcriptMsg(finished, hs.transcript); err != nil { - return err - } - - // Derive secrets that take context through the server Finished. - - hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, - clientApplicationTrafficLabel, hs.transcript) - serverSecret := hs.suite.deriveSecret(hs.masterSecret, - serverApplicationTrafficLabel, hs.transcript) - c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret) - - err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) - - return nil -} - -func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { - c := hs.c - - if hs.certReq == nil { - return nil - } - - cert, err := c.getClientCertificate(toCertificateRequestInfo(&certificateRequestInfo{ - AcceptableCAs: hs.certReq.certificateAuthorities, - SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, - Version: c.vers, - ctx: hs.ctx, - })) - if err != nil { - return err - } - - certMsg := new(certificateMsgTLS13) - - certMsg.certificate = *cert - certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 - certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 - - if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { - return err - } - - // If we sent an empty certificate message, skip the CertificateVerify. - if len(cert.Certificate) == 0 { - return nil - } - - certVerifyMsg := new(certificateVerifyMsg) - certVerifyMsg.hasSignatureAlgorithm = true - - certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms) - if err != nil { - // getClientCertificate returned a certificate incompatible with the - // CertificateRequestInfo supported signature algorithms. - c.sendAlert(alertHandshakeFailure) - return err - } - - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - - signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) - if err != nil { - c.sendAlert(alertInternalError) - return errors.New("tls: failed to sign handshake: " + err.Error()) - } - certVerifyMsg.signature = sig - - if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { - return err - } - - return nil -} - -func (hs *clientHandshakeStateTLS13) sendClientFinished() error { - c := hs.c - - finished := &finishedMsg{ - verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), - } - - if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { - return err - } - - c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret) - - if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { - c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, - resumptionLabel, hs.transcript) - } - - if c.quic != nil { - if c.hand.Len() != 0 { - c.sendAlert(alertUnexpectedMessage) - } - c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret) - } - - return nil -} - -func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { - if !c.isClient { - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: received new session ticket from a client") - } - - if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return nil - } - - // See RFC 8446, Section 4.6.1. - if msg.lifetime == 0 { - return nil - } - lifetime := time.Duration(msg.lifetime) * time.Second - if lifetime > maxSessionTicketLifetime { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: received a session ticket with invalid lifetime") - } - - cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) - if cipherSuite == nil || c.resumptionSecret == nil { - return c.sendAlert(alertInternalError) - } - - // We need to save the max_early_data_size that the server sent us, in order - // to decide if we're going to try 0-RTT with this ticket. - // However, at the same time, the qtls.ClientSessionTicket needs to be equal to - // the tls.ClientSessionTicket, so we can't just add a new field to the struct. - // We therefore abuse the nonce field (which is a byte slice) - nonceWithEarlyData := make([]byte, len(msg.nonce)+4) - binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData) - copy(nonceWithEarlyData[4:], msg.nonce) - - var appData []byte - if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil { - appData = c.extraConfig.GetAppDataForSessionState() - } - var b cryptobyte.Builder - b.AddUint16(clientSessionStateVersion) // revision - b.AddUint32(msg.maxEarlyData) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(appData) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(msg.nonce) - }) - - // Save the resumption_master_secret and nonce instead of deriving the PSK - // to do the least amount of work on NewSessionTicket messages before we - // know if the ticket will be used. Forward secrecy of resumed connections - // is guaranteed by the requirement for pskModeDHE. - session := &clientSessionState{ - sessionTicket: msg.label, - vers: c.vers, - cipherSuite: c.cipherSuite, - masterSecret: c.resumptionSecret, - serverCertificates: c.peerCertificates, - verifiedChains: c.verifiedChains, - receivedAt: c.config.time(), - nonce: b.BytesOrPanic(), - useBy: c.config.time().Add(lifetime), - ageAdd: msg.ageAdd, - ocspResponse: c.ocspResponse, - scts: c.scts, - } - - cacheKey := c.clientSessionCacheKey() - if cacheKey != "" { - c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session)) - } - - return nil -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/handshake_messages.go b/vendor/github.com/quic-go/qtls-go1-20/handshake_messages.go deleted file mode 100644 index 37b012363..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/handshake_messages.go +++ /dev/null @@ -1,1886 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "errors" - "fmt" - "strings" - - "golang.org/x/crypto/cryptobyte" -) - -// The marshalingFunction type is an adapter to allow the use of ordinary -// functions as cryptobyte.MarshalingValue. -type marshalingFunction func(b *cryptobyte.Builder) error - -func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { - return f(b) -} - -// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If -// the length of the sequence is not the value specified, it produces an error. -func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { - b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { - if len(v) != n { - return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) - } - b.AddBytes(v) - return nil - })) -} - -// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. -func addUint64(b *cryptobyte.Builder, v uint64) { - b.AddUint32(uint32(v >> 32)) - b.AddUint32(uint32(v)) -} - -// readUint64 decodes a big-endian, 64-bit value into out and advances over it. -// It reports whether the read was successful. -func readUint64(s *cryptobyte.String, out *uint64) bool { - var hi, lo uint32 - if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { - return false - } - *out = uint64(hi)<<32 | uint64(lo) - return true -} - -// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. -func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) -} - -// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. -func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) -} - -// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. -func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) -} - -type clientHelloMsg struct { - raw []byte - vers uint16 - random []byte - sessionId []byte - cipherSuites []uint16 - compressionMethods []uint8 - serverName string - ocspStapling bool - supportedCurves []CurveID - supportedPoints []uint8 - ticketSupported bool - sessionTicket []uint8 - supportedSignatureAlgorithms []SignatureScheme - supportedSignatureAlgorithmsCert []SignatureScheme - secureRenegotiationSupported bool - secureRenegotiation []byte - alpnProtocols []string - scts bool - supportedVersions []uint16 - cookie []byte - keyShares []keyShare - earlyData bool - pskModes []uint8 - pskIdentities []pskIdentity - pskBinders [][]byte - quicTransportParameters []byte -} - -func (m *clientHelloMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var exts cryptobyte.Builder - if len(m.serverName) > 0 { - // RFC 6066, Section 3 - exts.AddUint16(extensionServerName) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8(0) // name_type = host_name - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes([]byte(m.serverName)) - }) - }) - }) - } - if m.ocspStapling { - // RFC 4366, Section 3.6 - exts.AddUint16(extensionStatusRequest) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8(1) // status_type = ocsp - exts.AddUint16(0) // empty responder_id_list - exts.AddUint16(0) // empty request_extensions - }) - } - if len(m.supportedCurves) > 0 { - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - exts.AddUint16(extensionSupportedCurves) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, curve := range m.supportedCurves { - exts.AddUint16(uint16(curve)) - } - }) - }) - } - if len(m.supportedPoints) > 0 { - // RFC 4492, Section 5.1.2 - exts.AddUint16(extensionSupportedPoints) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.supportedPoints) - }) - }) - } - if m.ticketSupported { - // RFC 5077, Section 3.2 - exts.AddUint16(extensionSessionTicket) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.sessionTicket) - }) - } - if len(m.supportedSignatureAlgorithms) > 0 { - // RFC 5246, Section 7.4.1.4.1 - exts.AddUint16(extensionSignatureAlgorithms) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - exts.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - // RFC 8446, Section 4.2.3 - exts.AddUint16(extensionSignatureAlgorithmsCert) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - exts.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if m.secureRenegotiationSupported { - // RFC 5746, Section 3.2 - exts.AddUint16(extensionRenegotiationInfo) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocols) > 0 { - // RFC 7301, Section 3.1 - exts.AddUint16(extensionALPN) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, proto := range m.alpnProtocols { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes([]byte(proto)) - }) - } - }) - }) - } - if m.scts { - // RFC 6962, Section 3.3.1 - exts.AddUint16(extensionSCT) - exts.AddUint16(0) // empty extension_data - } - if len(m.supportedVersions) > 0 { - // RFC 8446, Section 4.2.1 - exts.AddUint16(extensionSupportedVersions) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, vers := range m.supportedVersions { - exts.AddUint16(vers) - } - }) - }) - } - if len(m.cookie) > 0 { - // RFC 8446, Section 4.2.2 - exts.AddUint16(extensionCookie) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.cookie) - }) - }) - } - if len(m.keyShares) > 0 { - // RFC 8446, Section 4.2.8 - exts.AddUint16(extensionKeyShare) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, ks := range m.keyShares { - exts.AddUint16(uint16(ks.group)) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(ks.data) - }) - } - }) - }) - } - if m.earlyData { - // RFC 8446, Section 4.2.10 - exts.AddUint16(extensionEarlyData) - exts.AddUint16(0) // empty extension_data - } - if len(m.pskModes) > 0 { - // RFC 8446, Section 4.2.9 - exts.AddUint16(extensionPSKModes) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.pskModes) - }) - }) - } - if m.quicTransportParameters != nil { // marshal zero-length parameters when present - // RFC 9001, Section 8.2 - exts.AddUint16(extensionQUICTransportParameters) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.quicTransportParameters) - }) - } - if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension - // RFC 8446, Section 4.2.11 - exts.AddUint16(extensionPreSharedKey) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, psk := range m.pskIdentities { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(psk.label) - }) - exts.AddUint32(psk.obfuscatedTicketAge) - } - }) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(binder) - }) - } - }) - }) - } - extBytes, err := exts.Bytes() - if err != nil { - return nil, err - } - - var b cryptobyte.Builder - b.AddUint8(typeClientHello) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.vers) - addBytesWithLength(b, m.random, 32) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionId) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, suite := range m.cipherSuites { - b.AddUint16(suite) - } - }) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.compressionMethods) - }) - - if len(extBytes) > 0 { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(extBytes) - }) - } - }) - - m.raw, err = b.Bytes() - return m.raw, err -} - -// marshalWithoutBinders returns the ClientHello through the -// PreSharedKeyExtension.identities field, according to RFC 8446, Section -// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. -func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) { - bindersLen := 2 // uint16 length prefix - for _, binder := range m.pskBinders { - bindersLen += 1 // uint8 length prefix - bindersLen += len(binder) - } - - fullMessage, err := m.marshal() - if err != nil { - return nil, err - } - return fullMessage[:len(fullMessage)-bindersLen], nil -} - -// updateBinders updates the m.pskBinders field, if necessary updating the -// cached marshaled representation. The supplied binders must have the same -// length as the current m.pskBinders. -func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { - if len(pskBinders) != len(m.pskBinders) { - return errors.New("tls: internal error: pskBinders length mismatch") - } - for i := range m.pskBinders { - if len(pskBinders[i]) != len(m.pskBinders[i]) { - return errors.New("tls: internal error: pskBinders length mismatch") - } - } - m.pskBinders = pskBinders - if m.raw != nil { - helloBytes, err := m.marshalWithoutBinders() - if err != nil { - return err - } - lenWithoutBinders := len(helloBytes) - b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders]) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) { - return errors.New("tls: internal error: failed to update binders") - } - } - - return nil -} - -func (m *clientHelloMsg) unmarshal(data []byte) bool { - *m = clientHelloMsg{raw: data} - s := cryptobyte.String(data) - - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || - !readUint8LengthPrefixed(&s, &m.sessionId) { - return false - } - - var cipherSuites cryptobyte.String - if !s.ReadUint16LengthPrefixed(&cipherSuites) { - return false - } - m.cipherSuites = []uint16{} - m.secureRenegotiationSupported = false - for !cipherSuites.Empty() { - var suite uint16 - if !cipherSuites.ReadUint16(&suite) { - return false - } - if suite == scsvRenegotiation { - m.secureRenegotiationSupported = true - } - m.cipherSuites = append(m.cipherSuites, suite) - } - - if !readUint8LengthPrefixed(&s, &m.compressionMethods) { - return false - } - - if s.Empty() { - // ClientHello is optionally followed by extension data - return true - } - - var extensions cryptobyte.String - if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - seenExts := make(map[uint16]bool) - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - if seenExts[extension] { - return false - } - seenExts[extension] = true - - switch extension { - case extensionServerName: - // RFC 6066, Section 3 - var nameList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { - return false - } - for !nameList.Empty() { - var nameType uint8 - var serverName cryptobyte.String - if !nameList.ReadUint8(&nameType) || - !nameList.ReadUint16LengthPrefixed(&serverName) || - serverName.Empty() { - return false - } - if nameType != 0 { - continue - } - if len(m.serverName) != 0 { - // Multiple names of the same name_type are prohibited. - return false - } - m.serverName = string(serverName) - // An SNI value may not include a trailing dot. - if strings.HasSuffix(m.serverName, ".") { - return false - } - } - case extensionStatusRequest: - // RFC 4366, Section 3.6 - var statusType uint8 - var ignored cryptobyte.String - if !extData.ReadUint8(&statusType) || - !extData.ReadUint16LengthPrefixed(&ignored) || - !extData.ReadUint16LengthPrefixed(&ignored) { - return false - } - m.ocspStapling = statusType == statusTypeOCSP - case extensionSupportedCurves: - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - var curves cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { - return false - } - for !curves.Empty() { - var curve uint16 - if !curves.ReadUint16(&curve) { - return false - } - m.supportedCurves = append(m.supportedCurves, CurveID(curve)) - } - case extensionSupportedPoints: - // RFC 4492, Section 5.1.2 - if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || - len(m.supportedPoints) == 0 { - return false - } - case extensionSessionTicket: - // RFC 5077, Section 3.2 - m.ticketSupported = true - extData.ReadBytes(&m.sessionTicket, len(extData)) - case extensionSignatureAlgorithms: - // RFC 5246, Section 7.4.1.4.1 - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithms = append( - m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) - } - case extensionSignatureAlgorithmsCert: - // RFC 8446, Section 4.2.3 - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithmsCert = append( - m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) - } - case extensionRenegotiationInfo: - // RFC 5746, Section 3.2 - if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { - return false - } - m.secureRenegotiationSupported = true - case extensionALPN: - // RFC 7301, Section 3.1 - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - for !protoList.Empty() { - var proto cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { - return false - } - m.alpnProtocols = append(m.alpnProtocols, string(proto)) - } - case extensionSCT: - // RFC 6962, Section 3.3.1 - m.scts = true - case extensionSupportedVersions: - // RFC 8446, Section 4.2.1 - var versList cryptobyte.String - if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() { - return false - } - for !versList.Empty() { - var vers uint16 - if !versList.ReadUint16(&vers) { - return false - } - m.supportedVersions = append(m.supportedVersions, vers) - } - case extensionCookie: - // RFC 8446, Section 4.2.2 - if !readUint16LengthPrefixed(&extData, &m.cookie) || - len(m.cookie) == 0 { - return false - } - case extensionKeyShare: - // RFC 8446, Section 4.2.8 - var clientShares cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&clientShares) { - return false - } - for !clientShares.Empty() { - var ks keyShare - if !clientShares.ReadUint16((*uint16)(&ks.group)) || - !readUint16LengthPrefixed(&clientShares, &ks.data) || - len(ks.data) == 0 { - return false - } - m.keyShares = append(m.keyShares, ks) - } - case extensionEarlyData: - // RFC 8446, Section 4.2.10 - m.earlyData = true - case extensionPSKModes: - // RFC 8446, Section 4.2.9 - if !readUint8LengthPrefixed(&extData, &m.pskModes) { - return false - } - case extensionQUICTransportParameters: - m.quicTransportParameters = make([]byte, len(extData)) - if !extData.CopyBytes(m.quicTransportParameters) { - return false - } - case extensionPreSharedKey: - // RFC 8446, Section 4.2.11 - if !extensions.Empty() { - return false // pre_shared_key must be the last extension - } - var identities cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { - return false - } - for !identities.Empty() { - var psk pskIdentity - if !readUint16LengthPrefixed(&identities, &psk.label) || - !identities.ReadUint32(&psk.obfuscatedTicketAge) || - len(psk.label) == 0 { - return false - } - m.pskIdentities = append(m.pskIdentities, psk) - } - var binders cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { - return false - } - for !binders.Empty() { - var binder []byte - if !readUint8LengthPrefixed(&binders, &binder) || - len(binder) == 0 { - return false - } - m.pskBinders = append(m.pskBinders, binder) - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type serverHelloMsg struct { - raw []byte - vers uint16 - random []byte - sessionId []byte - cipherSuite uint16 - compressionMethod uint8 - ocspStapling bool - ticketSupported bool - secureRenegotiationSupported bool - secureRenegotiation []byte - alpnProtocol string - scts [][]byte - supportedVersion uint16 - serverShare keyShare - selectedIdentityPresent bool - selectedIdentity uint16 - supportedPoints []uint8 - - // HelloRetryRequest extensions - cookie []byte - selectedGroup CurveID -} - -func (m *serverHelloMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var exts cryptobyte.Builder - if m.ocspStapling { - exts.AddUint16(extensionStatusRequest) - exts.AddUint16(0) // empty extension_data - } - if m.ticketSupported { - exts.AddUint16(extensionSessionTicket) - exts.AddUint16(0) // empty extension_data - } - if m.secureRenegotiationSupported { - exts.AddUint16(extensionRenegotiationInfo) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocol) > 0 { - exts.AddUint16(extensionALPN) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - if len(m.scts) > 0 { - exts.AddUint16(extensionSCT) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - for _, sct := range m.scts { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(sct) - }) - } - }) - }) - } - if m.supportedVersion != 0 { - exts.AddUint16(extensionSupportedVersions) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16(m.supportedVersion) - }) - } - if m.serverShare.group != 0 { - exts.AddUint16(extensionKeyShare) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16(uint16(m.serverShare.group)) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.serverShare.data) - }) - }) - } - if m.selectedIdentityPresent { - exts.AddUint16(extensionPreSharedKey) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16(m.selectedIdentity) - }) - } - - if len(m.cookie) > 0 { - exts.AddUint16(extensionCookie) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.cookie) - }) - }) - } - if m.selectedGroup != 0 { - exts.AddUint16(extensionKeyShare) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint16(uint16(m.selectedGroup)) - }) - } - if len(m.supportedPoints) > 0 { - exts.AddUint16(extensionSupportedPoints) - exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.supportedPoints) - }) - }) - } - - extBytes, err := exts.Bytes() - if err != nil { - return nil, err - } - - var b cryptobyte.Builder - b.AddUint8(typeServerHello) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.vers) - addBytesWithLength(b, m.random, 32) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionId) - }) - b.AddUint16(m.cipherSuite) - b.AddUint8(m.compressionMethod) - - if len(extBytes) > 0 { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(extBytes) - }) - } - }) - - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *serverHelloMsg) unmarshal(data []byte) bool { - *m = serverHelloMsg{raw: data} - s := cryptobyte.String(data) - - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || - !readUint8LengthPrefixed(&s, &m.sessionId) || - !s.ReadUint16(&m.cipherSuite) || - !s.ReadUint8(&m.compressionMethod) { - return false - } - - if s.Empty() { - // ServerHello is optionally followed by extension data - return true - } - - var extensions cryptobyte.String - if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - seenExts := make(map[uint16]bool) - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - if seenExts[extension] { - return false - } - seenExts[extension] = true - - switch extension { - case extensionStatusRequest: - m.ocspStapling = true - case extensionSessionTicket: - m.ticketSupported = true - case extensionRenegotiationInfo: - if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { - return false - } - m.secureRenegotiationSupported = true - case extensionALPN: - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - var proto cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || - proto.Empty() || !protoList.Empty() { - return false - } - m.alpnProtocol = string(proto) - case extensionSCT: - var sctList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { - return false - } - for !sctList.Empty() { - var sct []byte - if !readUint16LengthPrefixed(&sctList, &sct) || - len(sct) == 0 { - return false - } - m.scts = append(m.scts, sct) - } - case extensionSupportedVersions: - if !extData.ReadUint16(&m.supportedVersion) { - return false - } - case extensionCookie: - if !readUint16LengthPrefixed(&extData, &m.cookie) || - len(m.cookie) == 0 { - return false - } - case extensionKeyShare: - // This extension has different formats in SH and HRR, accept either - // and let the handshake logic decide. See RFC 8446, Section 4.2.8. - if len(extData) == 2 { - if !extData.ReadUint16((*uint16)(&m.selectedGroup)) { - return false - } - } else { - if !extData.ReadUint16((*uint16)(&m.serverShare.group)) || - !readUint16LengthPrefixed(&extData, &m.serverShare.data) { - return false - } - } - case extensionPreSharedKey: - m.selectedIdentityPresent = true - if !extData.ReadUint16(&m.selectedIdentity) { - return false - } - case extensionSupportedPoints: - // RFC 4492, Section 5.1.2 - if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || - len(m.supportedPoints) == 0 { - return false - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type encryptedExtensionsMsg struct { - raw []byte - alpnProtocol string - quicTransportParameters []byte - earlyData bool -} - -func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeEncryptedExtensions) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if len(m.alpnProtocol) > 0 { - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - if m.quicTransportParameters != nil { // marshal zero-length parameters when present - // draft-ietf-quic-tls-32, Section 8.2 - b.AddUint16(extensionQUICTransportParameters) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.quicTransportParameters) - }) - } - if m.earlyData { - // RFC 8446, Section 4.2.10 - b.AddUint16(extensionEarlyData) - b.AddUint16(0) // empty extension_data - } - }) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { - *m = encryptedExtensionsMsg{raw: data} - s := cryptobyte.String(data) - - var extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionALPN: - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - var proto cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || - proto.Empty() || !protoList.Empty() { - return false - } - m.alpnProtocol = string(proto) - case extensionQUICTransportParameters: - m.quicTransportParameters = make([]byte, len(extData)) - if !extData.CopyBytes(m.quicTransportParameters) { - return false - } - case extensionEarlyData: - m.earlyData = true - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type endOfEarlyDataMsg struct{} - -func (m *endOfEarlyDataMsg) marshal() ([]byte, error) { - x := make([]byte, 4) - x[0] = typeEndOfEarlyData - return x, nil -} - -func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} - -type keyUpdateMsg struct { - raw []byte - updateRequested bool -} - -func (m *keyUpdateMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeKeyUpdate) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - if m.updateRequested { - b.AddUint8(1) - } else { - b.AddUint8(0) - } - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *keyUpdateMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - var updateRequested uint8 - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8(&updateRequested) || !s.Empty() { - return false - } - switch updateRequested { - case 0: - m.updateRequested = false - case 1: - m.updateRequested = true - default: - return false - } - return true -} - -type newSessionTicketMsgTLS13 struct { - raw []byte - lifetime uint32 - ageAdd uint32 - nonce []byte - label []byte - maxEarlyData uint32 -} - -func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeNewSessionTicket) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint32(m.lifetime) - b.AddUint32(m.ageAdd) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.nonce) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.label) - }) - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.maxEarlyData > 0 { - b.AddUint16(extensionEarlyData) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint32(m.maxEarlyData) - }) - } - }) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { - *m = newSessionTicketMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint32(&m.lifetime) || - !s.ReadUint32(&m.ageAdd) || - !readUint8LengthPrefixed(&s, &m.nonce) || - !readUint16LengthPrefixed(&s, &m.label) || - !s.ReadUint16LengthPrefixed(&extensions) || - !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionEarlyData: - if !extData.ReadUint32(&m.maxEarlyData) { - return false - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type certificateRequestMsgTLS13 struct { - raw []byte - ocspStapling bool - scts bool - supportedSignatureAlgorithms []SignatureScheme - supportedSignatureAlgorithmsCert []SignatureScheme - certificateAuthorities [][]byte -} - -func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateRequest) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - // certificate_request_context (SHALL be zero length unless used for - // post-handshake authentication) - b.AddUint8(0) - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.ocspStapling { - b.AddUint16(extensionStatusRequest) - b.AddUint16(0) // empty extension_data - } - if m.scts { - // RFC 8446, Section 4.4.2.1 makes no mention of - // signed_certificate_timestamp in CertificateRequest, but - // "Extensions in the Certificate message from the client MUST - // correspond to extensions in the CertificateRequest message - // from the server." and it appears in the table in Section 4.2. - b.AddUint16(extensionSCT) - b.AddUint16(0) // empty extension_data - } - if len(m.supportedSignatureAlgorithms) > 0 { - b.AddUint16(extensionSignatureAlgorithms) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - b.AddUint16(extensionSignatureAlgorithmsCert) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.certificateAuthorities) > 0 { - b.AddUint16(extensionCertificateAuthorities) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, ca := range m.certificateAuthorities { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ca) - }) - } - }) - }) - } - }) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { - *m = certificateRequestMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var context, extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || - !s.ReadUint16LengthPrefixed(&extensions) || - !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionStatusRequest: - m.ocspStapling = true - case extensionSCT: - m.scts = true - case extensionSignatureAlgorithms: - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithms = append( - m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) - } - case extensionSignatureAlgorithmsCert: - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithmsCert = append( - m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) - } - case extensionCertificateAuthorities: - var auths cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() { - return false - } - for !auths.Empty() { - var ca []byte - if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 { - return false - } - m.certificateAuthorities = append(m.certificateAuthorities, ca) - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type certificateMsg struct { - raw []byte - certificates [][]byte -} - -func (m *certificateMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var i int - for _, slice := range m.certificates { - i += len(slice) - } - - length := 3 + 3*len(m.certificates) + i - x := make([]byte, 4+length) - x[0] = typeCertificate - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - - certificateOctets := length - 3 - x[4] = uint8(certificateOctets >> 16) - x[5] = uint8(certificateOctets >> 8) - x[6] = uint8(certificateOctets) - - y := x[7:] - for _, slice := range m.certificates { - y[0] = uint8(len(slice) >> 16) - y[1] = uint8(len(slice) >> 8) - y[2] = uint8(len(slice)) - copy(y[3:], slice) - y = y[3+len(slice):] - } - - m.raw = x - return m.raw, nil -} - -func (m *certificateMsg) unmarshal(data []byte) bool { - if len(data) < 7 { - return false - } - - m.raw = data - certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) - if uint32(len(data)) != certsLen+7 { - return false - } - - numCerts := 0 - d := data[7:] - for certsLen > 0 { - if len(d) < 4 { - return false - } - certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) - if uint32(len(d)) < 3+certLen { - return false - } - d = d[3+certLen:] - certsLen -= 3 + certLen - numCerts++ - } - - m.certificates = make([][]byte, numCerts) - d = data[7:] - for i := 0; i < numCerts; i++ { - certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) - m.certificates[i] = d[3 : 3+certLen] - d = d[3+certLen:] - } - - return true -} - -type certificateMsgTLS13 struct { - raw []byte - certificate Certificate - ocspStapling bool - scts bool -} - -func (m *certificateMsgTLS13) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificate) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(0) // certificate_request_context - - certificate := m.certificate - if !m.ocspStapling { - certificate.OCSPStaple = nil - } - if !m.scts { - certificate.SignedCertificateTimestamps = nil - } - marshalCertificate(b, certificate) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - for i, cert := range certificate.Certificate { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(cert) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if i > 0 { - // This library only supports OCSP and SCT for leaf certificates. - return - } - if certificate.OCSPStaple != nil { - b.AddUint16(extensionStatusRequest) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(statusTypeOCSP) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(certificate.OCSPStaple) - }) - }) - } - if certificate.SignedCertificateTimestamps != nil { - b.AddUint16(extensionSCT) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sct := range certificate.SignedCertificateTimestamps { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(sct) - }) - } - }) - }) - } - }) - } - }) -} - -func (m *certificateMsgTLS13) unmarshal(data []byte) bool { - *m = certificateMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var context cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || - !unmarshalCertificate(&s, &m.certificate) || - !s.Empty() { - return false - } - - m.scts = m.certificate.SignedCertificateTimestamps != nil - m.ocspStapling = m.certificate.OCSPStaple != nil - - return true -} - -func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { - var certList cryptobyte.String - if !s.ReadUint24LengthPrefixed(&certList) { - return false - } - for !certList.Empty() { - var cert []byte - var extensions cryptobyte.String - if !readUint24LengthPrefixed(&certList, &cert) || - !certList.ReadUint16LengthPrefixed(&extensions) { - return false - } - certificate.Certificate = append(certificate.Certificate, cert) - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - if len(certificate.Certificate) > 1 { - // This library only supports OCSP and SCT for leaf certificates. - continue - } - - switch extension { - case extensionStatusRequest: - var statusType uint8 - if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) || - len(certificate.OCSPStaple) == 0 { - return false - } - case extensionSCT: - var sctList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { - return false - } - for !sctList.Empty() { - var sct []byte - if !readUint16LengthPrefixed(&sctList, &sct) || - len(sct) == 0 { - return false - } - certificate.SignedCertificateTimestamps = append( - certificate.SignedCertificateTimestamps, sct) - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - } - return true -} - -type serverKeyExchangeMsg struct { - raw []byte - key []byte -} - -func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - length := len(m.key) - x := make([]byte, length+4) - x[0] = typeServerKeyExchange - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - copy(x[4:], m.key) - - m.raw = x - return x, nil -} - -func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 4 { - return false - } - m.key = data[4:] - return true -} - -type certificateStatusMsg struct { - raw []byte - response []byte -} - -func (m *certificateStatusMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateStatus) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(statusTypeOCSP) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.response) - }) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *certificateStatusMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - var statusType uint8 - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !readUint24LengthPrefixed(&s, &m.response) || - len(m.response) == 0 || !s.Empty() { - return false - } - return true -} - -type serverHelloDoneMsg struct{} - -func (m *serverHelloDoneMsg) marshal() ([]byte, error) { - x := make([]byte, 4) - x[0] = typeServerHelloDone - return x, nil -} - -func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} - -type clientKeyExchangeMsg struct { - raw []byte - ciphertext []byte -} - -func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - length := len(m.ciphertext) - x := make([]byte, length+4) - x[0] = typeClientKeyExchange - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - copy(x[4:], m.ciphertext) - - m.raw = x - return x, nil -} - -func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 4 { - return false - } - l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) - if l != len(data)-4 { - return false - } - m.ciphertext = data[4:] - return true -} - -type finishedMsg struct { - raw []byte - verifyData []byte -} - -func (m *finishedMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeFinished) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.verifyData) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *finishedMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - return s.Skip(1) && - readUint24LengthPrefixed(&s, &m.verifyData) && - s.Empty() -} - -type certificateRequestMsg struct { - raw []byte - // hasSignatureAlgorithm indicates whether this message includes a list of - // supported signature algorithms. This change was introduced with TLS 1.2. - hasSignatureAlgorithm bool - - certificateTypes []byte - supportedSignatureAlgorithms []SignatureScheme - certificateAuthorities [][]byte -} - -func (m *certificateRequestMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - // See RFC 4346, Section 7.4.4. - length := 1 + len(m.certificateTypes) + 2 - casLength := 0 - for _, ca := range m.certificateAuthorities { - casLength += 2 + len(ca) - } - length += casLength - - if m.hasSignatureAlgorithm { - length += 2 + 2*len(m.supportedSignatureAlgorithms) - } - - x := make([]byte, 4+length) - x[0] = typeCertificateRequest - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - - x[4] = uint8(len(m.certificateTypes)) - - copy(x[5:], m.certificateTypes) - y := x[5+len(m.certificateTypes):] - - if m.hasSignatureAlgorithm { - n := len(m.supportedSignatureAlgorithms) * 2 - y[0] = uint8(n >> 8) - y[1] = uint8(n) - y = y[2:] - for _, sigAlgo := range m.supportedSignatureAlgorithms { - y[0] = uint8(sigAlgo >> 8) - y[1] = uint8(sigAlgo) - y = y[2:] - } - } - - y[0] = uint8(casLength >> 8) - y[1] = uint8(casLength) - y = y[2:] - for _, ca := range m.certificateAuthorities { - y[0] = uint8(len(ca) >> 8) - y[1] = uint8(len(ca)) - y = y[2:] - copy(y, ca) - y = y[len(ca):] - } - - m.raw = x - return m.raw, nil -} - -func (m *certificateRequestMsg) unmarshal(data []byte) bool { - m.raw = data - - if len(data) < 5 { - return false - } - - length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) - if uint32(len(data))-4 != length { - return false - } - - numCertTypes := int(data[4]) - data = data[5:] - if numCertTypes == 0 || len(data) <= numCertTypes { - return false - } - - m.certificateTypes = make([]byte, numCertTypes) - if copy(m.certificateTypes, data) != numCertTypes { - return false - } - - data = data[numCertTypes:] - - if m.hasSignatureAlgorithm { - if len(data) < 2 { - return false - } - sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) - data = data[2:] - if sigAndHashLen&1 != 0 { - return false - } - if len(data) < int(sigAndHashLen) { - return false - } - numSigAlgos := sigAndHashLen / 2 - m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) - for i := range m.supportedSignatureAlgorithms { - m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) - data = data[2:] - } - } - - if len(data) < 2 { - return false - } - casLength := uint16(data[0])<<8 | uint16(data[1]) - data = data[2:] - if len(data) < int(casLength) { - return false - } - cas := make([]byte, casLength) - copy(cas, data) - data = data[casLength:] - - m.certificateAuthorities = nil - for len(cas) > 0 { - if len(cas) < 2 { - return false - } - caLen := uint16(cas[0])<<8 | uint16(cas[1]) - cas = cas[2:] - - if len(cas) < int(caLen) { - return false - } - - m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) - cas = cas[caLen:] - } - - return len(data) == 0 -} - -type certificateVerifyMsg struct { - raw []byte - hasSignatureAlgorithm bool // format change introduced in TLS 1.2 - signatureAlgorithm SignatureScheme - signature []byte -} - -func (m *certificateVerifyMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateVerify) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - if m.hasSignatureAlgorithm { - b.AddUint16(uint16(m.signatureAlgorithm)) - } - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.signature) - }) - }) - - var err error - m.raw, err = b.Bytes() - return m.raw, err -} - -func (m *certificateVerifyMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - if !s.Skip(4) { // message type and uint24 length field - return false - } - if m.hasSignatureAlgorithm { - if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) { - return false - } - } - return readUint16LengthPrefixed(&s, &m.signature) && s.Empty() -} - -type newSessionTicketMsg struct { - raw []byte - ticket []byte -} - -func (m *newSessionTicketMsg) marshal() ([]byte, error) { - if m.raw != nil { - return m.raw, nil - } - - // See RFC 5077, Section 3.3. - ticketLen := len(m.ticket) - length := 2 + 4 + ticketLen - x := make([]byte, 4+length) - x[0] = typeNewSessionTicket - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - x[8] = uint8(ticketLen >> 8) - x[9] = uint8(ticketLen) - copy(x[10:], m.ticket) - - m.raw = x - - return m.raw, nil -} - -func (m *newSessionTicketMsg) unmarshal(data []byte) bool { - m.raw = data - - if len(data) < 10 { - return false - } - - length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) - if uint32(len(data))-4 != length { - return false - } - - ticketLen := int(data[8])<<8 + int(data[9]) - if len(data)-10 != ticketLen { - return false - } - - m.ticket = data[10:] - - return true -} - -type helloRequestMsg struct { -} - -func (*helloRequestMsg) marshal() ([]byte, error) { - return []byte{typeHelloRequest, 0, 0, 0}, nil -} - -func (*helloRequestMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} - -type transcriptHash interface { - Write([]byte) (int, error) -} - -// transcriptMsg is a helper used to marshal and hash messages which typically -// are not written to the wire, and as such aren't hashed during Conn.writeRecord. -func transcriptMsg(msg handshakeMessage, h transcriptHash) error { - data, err := msg.marshal() - if err != nil { - return err - } - h.Write(data) - return nil -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/handshake_server.go b/vendor/github.com/quic-go/qtls-go1-20/handshake_server.go deleted file mode 100644 index 7539c95d5..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/handshake_server.go +++ /dev/null @@ -1,899 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "context" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "time" -) - -// serverHandshakeState contains details of a server handshake in progress. -// It's discarded once the handshake has completed. -type serverHandshakeState struct { - c *Conn - ctx context.Context - clientHello *clientHelloMsg - hello *serverHelloMsg - suite *cipherSuite - ecdheOk bool - ecSignOk bool - rsaDecryptOk bool - rsaSignOk bool - sessionState *sessionState - finishedHash finishedHash - masterSecret []byte - cert *Certificate -} - -// serverHandshake performs a TLS handshake as a server. -func (c *Conn) serverHandshake(ctx context.Context) error { - clientHello, err := c.readClientHello(ctx) - if err != nil { - return err - } - - if c.vers == VersionTLS13 { - hs := serverHandshakeStateTLS13{ - c: c, - ctx: ctx, - clientHello: clientHello, - } - return hs.handshake() - } - - hs := serverHandshakeState{ - c: c, - ctx: ctx, - clientHello: clientHello, - } - return hs.handshake() -} - -func (hs *serverHandshakeState) handshake() error { - c := hs.c - - if err := hs.processClientHello(); err != nil { - return err - } - - // For an overview of TLS handshaking, see RFC 5246, Section 7.3. - c.buffering = true - if hs.checkForResumption() { - // The client has included a session ticket and so we do an abbreviated handshake. - c.didResume = true - if err := hs.doResumeHandshake(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.sendSessionTicket(); err != nil { - return err - } - if err := hs.sendFinished(c.serverFinished[:]); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - c.clientFinishedIsFirst = false - if err := hs.readFinished(nil); err != nil { - return err - } - } else { - // The client didn't include a session ticket, or it wasn't - // valid so we do a full handshake. - if err := hs.pickCipherSuite(); err != nil { - return err - } - if err := hs.doFullHandshake(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.readFinished(c.clientFinished[:]); err != nil { - return err - } - c.clientFinishedIsFirst = true - c.buffering = true - if err := hs.sendSessionTicket(); err != nil { - return err - } - if err := hs.sendFinished(nil); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - } - - c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) - c.isHandshakeComplete.Store(true) - - return nil -} - -// readClientHello reads a ClientHello message and selects the protocol version. -func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { - // clientHelloMsg is included in the transcript, but we haven't initialized - // it yet. The respective handshake functions will record it themselves. - msg, err := c.readHandshake(nil) - if err != nil { - return nil, err - } - clientHello, ok := msg.(*clientHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return nil, unexpectedMessageError(clientHello, msg) - } - - var configForClient *config - originalConfig := c.config - if c.config.GetConfigForClient != nil { - chi := newClientHelloInfo(ctx, c, clientHello) - if cfc, err := c.config.GetConfigForClient(chi); err != nil { - c.sendAlert(alertInternalError) - return nil, err - } else if cfc != nil { - configForClient = fromConfig(cfc) - c.config = configForClient - } - } - c.ticketKeys = originalConfig.ticketKeys(configForClient) - - clientVersions := clientHello.supportedVersions - if len(clientHello.supportedVersions) == 0 { - clientVersions = supportedVersionsFromMax(clientHello.vers) - } - c.vers, ok = c.config.mutualVersion(roleServer, clientVersions) - if !ok { - c.sendAlert(alertProtocolVersion) - return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) - } - c.haveVers = true - c.in.version = c.vers - c.out.version = c.vers - - return clientHello, nil -} - -func (hs *serverHandshakeState) processClientHello() error { - c := hs.c - - hs.hello = new(serverHelloMsg) - hs.hello.vers = c.vers - - foundCompression := false - // We only support null compression, so check that the client offered it. - for _, compression := range hs.clientHello.compressionMethods { - if compression == compressionNone { - foundCompression = true - break - } - } - - if !foundCompression { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: client does not support uncompressed connections") - } - - hs.hello.random = make([]byte, 32) - serverRandom := hs.hello.random - // Downgrade protection canaries. See RFC 8446, Section 4.1.3. - maxVers := c.config.maxSupportedVersion(roleServer) - if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary { - if c.vers == VersionTLS12 { - copy(serverRandom[24:], downgradeCanaryTLS12) - } else { - copy(serverRandom[24:], downgradeCanaryTLS11) - } - serverRandom = serverRandom[:24] - } - _, err := io.ReadFull(c.config.rand(), serverRandom) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - if len(hs.clientHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: initial handshake had non-empty renegotiation extension") - } - - hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported - hs.hello.compressionMethod = compressionNone - if len(hs.clientHello.serverName) > 0 { - c.serverName = hs.clientHello.serverName - } - - selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, false) - if err != nil { - c.sendAlert(alertNoApplicationProtocol) - return err - } - hs.hello.alpnProtocol = selectedProto - c.clientProtocol = selectedProto - - hs.cert, err = c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello)) - if err != nil { - if err == errNoCertificates { - c.sendAlert(alertUnrecognizedName) - } else { - c.sendAlert(alertInternalError) - } - return err - } - if hs.clientHello.scts { - hs.hello.scts = hs.cert.SignedCertificateTimestamps - } - - hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) - - if hs.ecdheOk && len(hs.clientHello.supportedPoints) > 0 { - // Although omitting the ec_point_formats extension is permitted, some - // old OpenSSL version will refuse to handshake if not present. - // - // Per RFC 4492, section 5.1.2, implementations MUST support the - // uncompressed point format. See golang.org/issue/31943. - hs.hello.supportedPoints = []uint8{pointFormatUncompressed} - } - - if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok { - switch priv.Public().(type) { - case *ecdsa.PublicKey: - hs.ecSignOk = true - case ed25519.PublicKey: - hs.ecSignOk = true - case *rsa.PublicKey: - hs.rsaSignOk = true - default: - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public()) - } - } - if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok { - switch priv.Public().(type) { - case *rsa.PublicKey: - hs.rsaDecryptOk = true - default: - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public()) - } - } - - return nil -} - -// negotiateALPN picks a shared ALPN protocol that both sides support in server -// preference order. If ALPN is not configured or the peer doesn't support it, -// it returns "" and no error. -func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, error) { - if len(serverProtos) == 0 || len(clientProtos) == 0 { - if quic && len(serverProtos) != 0 { - // RFC 9001, Section 8.1 - return "", fmt.Errorf("tls: client did not request an application protocol") - } - return "", nil - } - var http11fallback bool - for _, s := range serverProtos { - for _, c := range clientProtos { - if s == c { - return s, nil - } - if s == "h2" && c == "http/1.1" { - http11fallback = true - } - } - } - // As a special case, let http/1.1 clients connect to h2 servers as if they - // didn't support ALPN. We used not to enforce protocol overlap, so over - // time a number of HTTP servers were configured with only "h2", but - // expected to accept connections from "http/1.1" clients. See Issue 46310. - if http11fallback { - return "", nil - } - return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos) -} - -// supportsECDHE returns whether ECDHE key exchanges can be used with this -// pre-TLS 1.3 client. -func supportsECDHE(c *config, supportedCurves []CurveID, supportedPoints []uint8) bool { - supportsCurve := false - for _, curve := range supportedCurves { - if c.supportsCurve(curve) { - supportsCurve = true - break - } - } - - supportsPointFormat := false - for _, pointFormat := range supportedPoints { - if pointFormat == pointFormatUncompressed { - supportsPointFormat = true - break - } - } - // Per RFC 8422, Section 5.1.2, if the Supported Point Formats extension is - // missing, uncompressed points are supported. If supportedPoints is empty, - // the extension must be missing, as an empty extension body is rejected by - // the parser. See https://go.dev/issue/49126. - if len(supportedPoints) == 0 { - supportsPointFormat = true - } - - return supportsCurve && supportsPointFormat -} - -func (hs *serverHandshakeState) pickCipherSuite() error { - c := hs.c - - preferenceOrder := cipherSuitesPreferenceOrder - if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { - preferenceOrder = cipherSuitesPreferenceOrderNoAES - } - - configCipherSuites := c.config.cipherSuites() - preferenceList := make([]uint16, 0, len(configCipherSuites)) - for _, suiteID := range preferenceOrder { - for _, id := range configCipherSuites { - if id == suiteID { - preferenceList = append(preferenceList, id) - break - } - } - } - - hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk) - if hs.suite == nil { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no cipher suite supported by both client and server") - } - c.cipherSuite = hs.suite.id - - for _, id := range hs.clientHello.cipherSuites { - if id == TLS_FALLBACK_SCSV { - // The client is doing a fallback connection. See RFC 7507. - if hs.clientHello.vers < c.config.maxSupportedVersion(roleServer) { - c.sendAlert(alertInappropriateFallback) - return errors.New("tls: client using inappropriate protocol fallback") - } - break - } - } - - return nil -} - -func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool { - if c.flags&suiteECDHE != 0 { - if !hs.ecdheOk { - return false - } - if c.flags&suiteECSign != 0 { - if !hs.ecSignOk { - return false - } - } else if !hs.rsaSignOk { - return false - } - } else if !hs.rsaDecryptOk { - return false - } - if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true -} - -// checkForResumption reports whether we should perform resumption on this connection. -func (hs *serverHandshakeState) checkForResumption() bool { - c := hs.c - - if c.config.SessionTicketsDisabled { - return false - } - - plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket) - if plaintext == nil { - return false - } - hs.sessionState = &sessionState{usedOldKey: usedOldKey} - ok := hs.sessionState.unmarshal(plaintext) - if !ok { - return false - } - - createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) - if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { - return false - } - - // Never resume a session for a different TLS version. - if c.vers != hs.sessionState.vers { - return false - } - - cipherSuiteOk := false - // Check that the client is still offering the ciphersuite in the session. - for _, id := range hs.clientHello.cipherSuites { - if id == hs.sessionState.cipherSuite { - cipherSuiteOk = true - break - } - } - if !cipherSuiteOk { - return false - } - - // Check that we also support the ciphersuite from the session. - hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, - c.config.cipherSuites(), hs.cipherSuiteOk) - if hs.suite == nil { - return false - } - - sessionHasClientCerts := len(hs.sessionState.certificates) != 0 - needClientCerts := requiresClientCert(c.config.ClientAuth) - if needClientCerts && !sessionHasClientCerts { - return false - } - if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { - return false - } - - return true -} - -func (hs *serverHandshakeState) doResumeHandshake() error { - c := hs.c - - hs.hello.cipherSuite = hs.suite.id - c.cipherSuite = hs.suite.id - // We echo the client's session ID in the ServerHello to let it know - // that we're doing a resumption. - hs.hello.sessionId = hs.clientHello.sessionId - hs.hello.ticketSupported = hs.sessionState.usedOldKey - hs.finishedHash = newFinishedHash(c.vers, hs.suite) - hs.finishedHash.discardHandshakeBuffer() - if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { - return err - } - if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { - return err - } - - if err := c.processCertsFromClient(Certificate{ - Certificate: hs.sessionState.certificates, - }); err != nil { - return err - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - hs.masterSecret = hs.sessionState.masterSecret - - return nil -} - -func (hs *serverHandshakeState) doFullHandshake() error { - c := hs.c - - if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { - hs.hello.ocspStapling = true - } - - hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled - hs.hello.cipherSuite = hs.suite.id - - hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) - if c.config.ClientAuth == NoClientCert { - // No need to keep a full record of the handshake if client - // certificates won't be used. - hs.finishedHash.discardHandshakeBuffer() - } - if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { - return err - } - if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { - return err - } - - certMsg := new(certificateMsg) - certMsg.certificates = hs.cert.Certificate - if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { - return err - } - - if hs.hello.ocspStapling { - certStatus := new(certificateStatusMsg) - certStatus.response = hs.cert.OCSPStaple - if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil { - return err - } - } - - keyAgreement := hs.suite.ka(c.vers) - skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) - if err != nil { - c.sendAlert(alertHandshakeFailure) - return err - } - if skx != nil { - if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil { - return err - } - } - - var certReq *certificateRequestMsg - if c.config.ClientAuth >= RequestClientCert { - // Request a client certificate - certReq = new(certificateRequestMsg) - certReq.certificateTypes = []byte{ - byte(certTypeRSASign), - byte(certTypeECDSASign), - } - if c.vers >= VersionTLS12 { - certReq.hasSignatureAlgorithm = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() - } - - // An empty list of certificateAuthorities signals to - // the client that it may send any certificate in response - // to our request. When we know the CAs we trust, then - // we can send them down, so that the client can choose - // an appropriate certificate to give to us. - if c.config.ClientCAs != nil { - certReq.certificateAuthorities = c.config.ClientCAs.Subjects() - } - if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil { - return err - } - } - - helloDone := new(serverHelloDoneMsg) - if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil { - return err - } - - if _, err := c.flush(); err != nil { - return err - } - - var pub crypto.PublicKey // public key for client auth, if any - - msg, err := c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - - // If we requested a client certificate, then the client must send a - // certificate message, even if it's empty. - if c.config.ClientAuth >= RequestClientCert { - certMsg, ok := msg.(*certificateMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - - if err := c.processCertsFromClient(Certificate{ - Certificate: certMsg.certificates, - }); err != nil { - return err - } - if len(certMsg.certificates) != 0 { - pub = c.peerCertificates[0].PublicKey - } - - msg, err = c.readHandshake(&hs.finishedHash) - if err != nil { - return err - } - } - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - // Get client key exchange - ckx, ok := msg.(*clientKeyExchangeMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(ckx, msg) - } - - preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) - if err != nil { - c.sendAlert(alertHandshakeFailure) - return err - } - hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) - if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil { - c.sendAlert(alertInternalError) - return err - } - - // If we received a client cert in response to our certificate request message, - // the client will send us a certificateVerifyMsg immediately after the - // clientKeyExchangeMsg. This message is a digest of all preceding - // handshake-layer messages that is signed using the private key corresponding - // to the client's certificate. This allows us to verify that the client is in - // possession of the private key of the certificate. - if len(c.peerCertificates) > 0 { - // certificateVerifyMsg is included in the transcript, but not until - // after we verify the handshake signature, since the state before - // this message was sent is used. - msg, err = c.readHandshake(nil) - if err != nil { - return err - } - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - var sigType uint8 - var sigHash crypto.Hash - if c.vers >= VersionTLS12 { - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - } - - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash) - if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the client certificate: " + err.Error()) - } - - if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil { - return err - } - } - - hs.finishedHash.discardHandshakeBuffer() - - return nil -} - -func (hs *serverHandshakeState) establishKeys() error { - c := hs.c - - clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := - keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) - - var clientCipher, serverCipher any - var clientHash, serverHash hash.Hash - - if hs.suite.aead == nil { - clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) - clientHash = hs.suite.mac(clientMAC) - serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) - serverHash = hs.suite.mac(serverMAC) - } else { - clientCipher = hs.suite.aead(clientKey, clientIV) - serverCipher = hs.suite.aead(serverKey, serverIV) - } - - c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) - c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) - - return nil -} - -func (hs *serverHandshakeState) readFinished(out []byte) error { - c := hs.c - - if err := c.readChangeCipherSpec(); err != nil { - return err - } - - // finishedMsg is included in the transcript, but not until after we - // check the client version, since the state before this message was - // sent is used during verification. - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - clientFinished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(clientFinished, msg) - } - - verify := hs.finishedHash.clientSum(hs.masterSecret) - if len(verify) != len(clientFinished.verifyData) || - subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: client's Finished message is incorrect") - } - - if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil { - return err - } - - copy(out, verify) - return nil -} - -func (hs *serverHandshakeState) sendSessionTicket() error { - // ticketSupported is set in a resumption handshake if the - // ticket from the client was encrypted with an old session - // ticket key and thus a refreshed ticket should be sent. - if !hs.hello.ticketSupported { - return nil - } - - c := hs.c - m := new(newSessionTicketMsg) - - createdAt := uint64(c.config.time().Unix()) - if hs.sessionState != nil { - // If this is re-wrapping an old key, then keep - // the original time it was created. - createdAt = hs.sessionState.createdAt - } - - var certsFromClient [][]byte - for _, cert := range c.peerCertificates { - certsFromClient = append(certsFromClient, cert.Raw) - } - state := sessionState{ - vers: c.vers, - cipherSuite: hs.suite.id, - createdAt: createdAt, - masterSecret: hs.masterSecret, - certificates: certsFromClient, - } - stateBytes, err := state.marshal() - if err != nil { - return err - } - m.ticket, err = c.encryptTicket(stateBytes) - if err != nil { - return err - } - - if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeState) sendFinished(out []byte) error { - c := hs.c - - if err := c.writeChangeCipherRecord(); err != nil { - return err - } - - finished := new(finishedMsg) - finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) - if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { - return err - } - - copy(out, finished.verifyData) - - return nil -} - -// processCertsFromClient takes a chain of client certificates either from a -// Certificates message or from a sessionState and verifies them. It returns -// the public key of the leaf certificate. -func (c *Conn) processCertsFromClient(certificate Certificate) error { - certificates := certificate.Certificate - certs := make([]*x509.Certificate, len(certificates)) - var err error - for i, asn1Data := range certificates { - if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { - c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to parse client certificate: " + err.Error()) - } - if certs[i].PublicKeyAlgorithm == x509.RSA && certs[i].PublicKey.(*rsa.PublicKey).N.BitLen() > maxRSAKeySize { - c.sendAlert(alertBadCertificate) - return fmt.Errorf("tls: client sent certificate containing RSA key larger than %d bits", maxRSAKeySize) - } - } - - if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) { - c.sendAlert(alertBadCertificate) - return errors.New("tls: client didn't provide a certificate") - } - - if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { - opts := x509.VerifyOptions{ - Roots: c.config.ClientCAs, - CurrentTime: c.config.time(), - Intermediates: x509.NewCertPool(), - KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - } - - for _, cert := range certs[1:] { - opts.Intermediates.AddCert(cert) - } - - chains, err := certs[0].Verify(opts) - if err != nil { - c.sendAlert(alertBadCertificate) - return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err} - } - - c.verifiedChains = chains - } - - c.peerCertificates = certs - c.ocspResponse = certificate.OCSPStaple - c.scts = certificate.SignedCertificateTimestamps - - if len(certs) > 0 { - switch certs[0].PublicKey.(type) { - case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: - default: - c.sendAlert(alertUnsupportedCertificate) - return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) - } - } - - if c.config.VerifyPeerCertificate != nil { - if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - return nil -} - -func newClientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { - supportedVersions := clientHello.supportedVersions - if len(clientHello.supportedVersions) == 0 { - supportedVersions = supportedVersionsFromMax(clientHello.vers) - } - - return toClientHelloInfo(&clientHelloInfo{ - CipherSuites: clientHello.cipherSuites, - ServerName: clientHello.serverName, - SupportedCurves: clientHello.supportedCurves, - SupportedPoints: clientHello.supportedPoints, - SignatureSchemes: clientHello.supportedSignatureAlgorithms, - SupportedProtos: clientHello.alpnProtocols, - SupportedVersions: supportedVersions, - Conn: c.conn, - config: toConfig(c.config), - ctx: ctx, - }) -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/handshake_server_tls13.go b/vendor/github.com/quic-go/qtls-go1-20/handshake_server_tls13.go deleted file mode 100644 index 03bc5eb51..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/handshake_server_tls13.go +++ /dev/null @@ -1,979 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "context" - "crypto" - "crypto/hmac" - "crypto/rsa" - "errors" - "hash" - "io" - "time" -) - -// maxClientPSKIdentities is the number of client PSK identities the server will -// attempt to validate. It will ignore the rest not to let cheap ClientHello -// messages cause too much work in session ticket decryption attempts. -const maxClientPSKIdentities = 5 - -type serverHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - clientHello *clientHelloMsg - hello *serverHelloMsg - alpnNegotiationErr error - encryptedExtensions *encryptedExtensionsMsg - sentDummyCCS bool - usingPSK bool - suite *cipherSuiteTLS13 - cert *Certificate - sigAlg SignatureScheme - earlySecret []byte - sharedKey []byte - handshakeSecret []byte - masterSecret []byte - trafficSecret []byte // client_application_traffic_secret_0 - transcript hash.Hash - clientFinished []byte - earlyData bool -} - -func (hs *serverHandshakeStateTLS13) handshake() error { - c := hs.c - - if needFIPS() { - return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode") - } - - // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. - if err := hs.processClientHello(); err != nil { - return err - } - if err := hs.checkForResumption(); err != nil { - return err - } - if err := hs.pickCertificate(); err != nil { - return err - } - c.buffering = true - if err := hs.sendServerParameters(); err != nil { - return err - } - if err := hs.sendServerCertificate(); err != nil { - return err - } - if err := hs.sendServerFinished(); err != nil { - return err - } - // Note that at this point we could start sending application data without - // waiting for the client's second flight, but the application might not - // expect the lack of replay protection of the ClientHello parameters. - if _, err := c.flush(); err != nil { - return err - } - if err := hs.readClientCertificate(); err != nil { - return err - } - if err := hs.readClientFinished(); err != nil { - return err - } - - c.isHandshakeComplete.Store(true) - - return nil -} - -func (hs *serverHandshakeStateTLS13) processClientHello() error { - c := hs.c - - hs.hello = new(serverHelloMsg) - hs.encryptedExtensions = new(encryptedExtensionsMsg) - - // TLS 1.3 froze the ServerHello.legacy_version field, and uses - // supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1. - hs.hello.vers = VersionTLS12 - hs.hello.supportedVersion = c.vers - - if len(hs.clientHello.supportedVersions) == 0 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client used the legacy version field to negotiate TLS 1.3") - } - - // Abort if the client is doing a fallback and landing lower than what we - // support. See RFC 7507, which however does not specify the interaction - // with supported_versions. The only difference is that with - // supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4] - // handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case, - // it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to - // TLS 1.2, because a TLS 1.3 server would abort here. The situation before - // supported_versions was not better because there was just no way to do a - // TLS 1.4 handshake without risking the server selecting TLS 1.3. - for _, id := range hs.clientHello.cipherSuites { - if id == TLS_FALLBACK_SCSV { - // Use c.vers instead of max(supported_versions) because an attacker - // could defeat this by adding an arbitrary high version otherwise. - if c.vers < c.config.maxSupportedVersion(roleServer) { - c.sendAlert(alertInappropriateFallback) - return errors.New("tls: client using inappropriate protocol fallback") - } - break - } - } - - if len(hs.clientHello.compressionMethods) != 1 || - hs.clientHello.compressionMethods[0] != compressionNone { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: TLS 1.3 client supports illegal compression methods") - } - - hs.hello.random = make([]byte, 32) - if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil { - c.sendAlert(alertInternalError) - return err - } - - if len(hs.clientHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: initial handshake had non-empty renegotiation extension") - } - - hs.hello.sessionId = hs.clientHello.sessionId - hs.hello.compressionMethod = compressionNone - - preferenceList := defaultCipherSuitesTLS13 - if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { - preferenceList = defaultCipherSuitesTLS13NoAES - } - for _, suiteID := range preferenceList { - hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID) - if hs.suite != nil { - break - } - } - if hs.suite == nil { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no cipher suite supported by both client and server") - } - c.cipherSuite = hs.suite.id - hs.hello.cipherSuite = hs.suite.id - hs.transcript = hs.suite.hash.New() - - // Pick the ECDHE group in server preference order, but give priority to - // groups with a key share, to avoid a HelloRetryRequest round-trip. - var selectedGroup CurveID - var clientKeyShare *keyShare -GroupSelection: - for _, preferredGroup := range c.config.curvePreferences() { - for _, ks := range hs.clientHello.keyShares { - if ks.group == preferredGroup { - selectedGroup = ks.group - clientKeyShare = &ks - break GroupSelection - } - } - if selectedGroup != 0 { - continue - } - for _, group := range hs.clientHello.supportedCurves { - if group == preferredGroup { - selectedGroup = group - break - } - } - } - if selectedGroup == 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no ECDHE curve supported by both client and server") - } - if clientKeyShare == nil { - if err := hs.doHelloRetryRequest(selectedGroup); err != nil { - return err - } - clientKeyShare = &hs.clientHello.keyShares[0] - } - - if _, ok := curveForCurveID(selectedGroup); !ok { - c.sendAlert(alertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - key, err := generateECDHEKey(c.config.rand(), selectedGroup) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()} - peerKey, err := key.Curve().NewPublicKey(clientKeyShare.data) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid client key share") - } - hs.sharedKey, err = key.ECDH(peerKey) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid client key share") - } - - if c.quic != nil { - if hs.clientHello.quicTransportParameters == nil { - // RFC 9001 Section 8.2. - c.sendAlert(alertMissingExtension) - return errors.New("tls: client did not send a quic_transport_parameters extension") - } - c.quicSetTransportParameters(hs.clientHello.quicTransportParameters) - } else { - if hs.clientHello.quicTransportParameters != nil { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: client sent an unexpected quic_transport_parameters extension") - } - } - - c.serverName = hs.clientHello.serverName - - selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil) - if err != nil { - hs.alpnNegotiationErr = err - } - hs.encryptedExtensions.alpnProtocol = selectedProto - c.clientProtocol = selectedProto - - return nil -} - -func (hs *serverHandshakeStateTLS13) checkForResumption() error { - c := hs.c - - if c.config.SessionTicketsDisabled { - return nil - } - - modeOK := false - for _, mode := range hs.clientHello.pskModes { - if mode == pskModeDHE { - modeOK = true - break - } - } - if !modeOK { - return nil - } - - if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid or missing PSK binders") - } - if len(hs.clientHello.pskIdentities) == 0 { - return nil - } - - for i, identity := range hs.clientHello.pskIdentities { - if i >= maxClientPSKIdentities { - break - } - - plaintext, _ := c.decryptTicket(identity.label) - if plaintext == nil { - continue - } - sessionState := new(sessionStateTLS13) - if ok := sessionState.unmarshal(plaintext); !ok { - continue - } - - if hs.clientHello.earlyData { - if sessionState.maxEarlyData == 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: client sent unexpected early data") - } - - if hs.alpnNegotiationErr == nil && sessionState.alpn == c.clientProtocol && - c.extraConfig != nil && c.extraConfig.Enable0RTT && - c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) { - hs.encryptedExtensions.earlyData = true - } - } - - createdAt := time.Unix(int64(sessionState.createdAt), 0) - if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { - continue - } - - // We don't check the obfuscated ticket age because it's affected by - // clock skew and it's only a freshness signal useful for shrinking the - // window for replay attacks, which don't affect us as we don't do 0-RTT. - - pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite) - if pskSuite == nil || pskSuite.hash != hs.suite.hash { - continue - } - - // PSK connections don't re-establish client certificates, but carry - // them over in the session ticket. Ensure the presence of client certs - // in the ticket is consistent with the configured requirements. - sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0 - needClientCerts := requiresClientCert(c.config.ClientAuth) - if needClientCerts && !sessionHasClientCerts { - continue - } - if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { - continue - } - - psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", - nil, hs.suite.hash.Size()) - hs.earlySecret = hs.suite.extract(psk, nil) - binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) - // Clone the transcript in case a HelloRetryRequest was recorded. - transcript := cloneHash(hs.transcript, hs.suite.hash) - if transcript == nil { - c.sendAlert(alertInternalError) - return errors.New("tls: internal error: failed to clone hash") - } - clientHelloBytes, err := hs.clientHello.marshalWithoutBinders() - if err != nil { - c.sendAlert(alertInternalError) - return err - } - transcript.Write(clientHelloBytes) - pskBinder := hs.suite.finishedHash(binderKey, transcript) - if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid PSK binder") - } - - if c.quic != nil && hs.clientHello.earlyData && hs.encryptedExtensions.earlyData && i == 0 && - sessionState.maxEarlyData > 0 && sessionState.cipherSuite == hs.suite.id { - hs.earlyData = true - - transcript := hs.suite.hash.New() - if err := transcriptMsg(hs.clientHello, transcript); err != nil { - return err - } - earlyTrafficSecret := hs.suite.deriveSecret(hs.earlySecret, clientEarlyTrafficLabel, transcript) - c.quicSetReadSecret(QUICEncryptionLevelEarly, hs.suite.id, earlyTrafficSecret) - } - - c.didResume = true - if err := c.processCertsFromClient(sessionState.certificate); err != nil { - return err - } - - hs.hello.selectedIdentityPresent = true - hs.hello.selectedIdentity = uint16(i) - hs.usingPSK = true - return nil - } - - return nil -} - -// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler -// interfaces implemented by standard library hashes to clone the state of in -// to a new instance of h. It returns nil if the operation fails. -func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash { - // Recreate the interface to avoid importing encoding. - type binaryMarshaler interface { - MarshalBinary() (data []byte, err error) - UnmarshalBinary(data []byte) error - } - marshaler, ok := in.(binaryMarshaler) - if !ok { - return nil - } - state, err := marshaler.MarshalBinary() - if err != nil { - return nil - } - out := h.New() - unmarshaler, ok := out.(binaryMarshaler) - if !ok { - return nil - } - if err := unmarshaler.UnmarshalBinary(state); err != nil { - return nil - } - return out -} - -func (hs *serverHandshakeStateTLS13) pickCertificate() error { - c := hs.c - - // Only one of PSK and certificates are used at a time. - if hs.usingPSK { - return nil - } - - // signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3. - if len(hs.clientHello.supportedSignatureAlgorithms) == 0 { - return c.sendAlert(alertMissingExtension) - } - - certificate, err := c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello)) - if err != nil { - if err == errNoCertificates { - c.sendAlert(alertUnrecognizedName) - } else { - c.sendAlert(alertInternalError) - } - return err - } - hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms) - if err != nil { - // getCertificate returned a certificate that is unsupported or - // incompatible with the client's signature algorithms. - c.sendAlert(alertHandshakeFailure) - return err - } - hs.cert = certificate - - return nil -} - -// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility -// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. -func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { - if hs.c.quic != nil { - return nil - } - if hs.sentDummyCCS { - return nil - } - hs.sentDummyCCS = true - - return hs.c.writeChangeCipherRecord() -} - -func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { - c := hs.c - - // The first ClientHello gets double-hashed into the transcript upon a - // HelloRetryRequest. See RFC 8446, Section 4.4.1. - if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { - return err - } - chHash := hs.transcript.Sum(nil) - hs.transcript.Reset() - hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - hs.transcript.Write(chHash) - - helloRetryRequest := &serverHelloMsg{ - vers: hs.hello.vers, - random: helloRetryRequestRandom, - sessionId: hs.hello.sessionId, - cipherSuite: hs.hello.cipherSuite, - compressionMethod: hs.hello.compressionMethod, - supportedVersion: hs.hello.supportedVersion, - selectedGroup: selectedGroup, - } - - if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil { - return err - } - - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - - // clientHelloMsg is not included in the transcript. - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - clientHello, ok := msg.(*clientHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(clientHello, msg) - } - - if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client sent invalid key share in second ClientHello") - } - - if clientHello.earlyData { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client indicated early data in second ClientHello") - } - - if illegalClientHelloChange(clientHello, hs.clientHello) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client illegally modified second ClientHello") - } - - if illegalClientHelloChange(clientHello, hs.clientHello) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client illegally modified second ClientHello") - } - - hs.clientHello = clientHello - return nil -} - -// illegalClientHelloChange reports whether the two ClientHello messages are -// different, with the exception of the changes allowed before and after a -// HelloRetryRequest. See RFC 8446, Section 4.1.2. -func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { - if len(ch.supportedVersions) != len(ch1.supportedVersions) || - len(ch.cipherSuites) != len(ch1.cipherSuites) || - len(ch.supportedCurves) != len(ch1.supportedCurves) || - len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) || - len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) || - len(ch.alpnProtocols) != len(ch1.alpnProtocols) { - return true - } - for i := range ch.supportedVersions { - if ch.supportedVersions[i] != ch1.supportedVersions[i] { - return true - } - } - for i := range ch.cipherSuites { - if ch.cipherSuites[i] != ch1.cipherSuites[i] { - return true - } - } - for i := range ch.supportedCurves { - if ch.supportedCurves[i] != ch1.supportedCurves[i] { - return true - } - } - for i := range ch.supportedSignatureAlgorithms { - if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] { - return true - } - } - for i := range ch.supportedSignatureAlgorithmsCert { - if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] { - return true - } - } - for i := range ch.alpnProtocols { - if ch.alpnProtocols[i] != ch1.alpnProtocols[i] { - return true - } - } - return ch.vers != ch1.vers || - !bytes.Equal(ch.random, ch1.random) || - !bytes.Equal(ch.sessionId, ch1.sessionId) || - !bytes.Equal(ch.compressionMethods, ch1.compressionMethods) || - ch.serverName != ch1.serverName || - ch.ocspStapling != ch1.ocspStapling || - !bytes.Equal(ch.supportedPoints, ch1.supportedPoints) || - ch.ticketSupported != ch1.ticketSupported || - !bytes.Equal(ch.sessionTicket, ch1.sessionTicket) || - ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported || - !bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) || - ch.scts != ch1.scts || - !bytes.Equal(ch.cookie, ch1.cookie) || - !bytes.Equal(ch.pskModes, ch1.pskModes) -} - -func (hs *serverHandshakeStateTLS13) sendServerParameters() error { - c := hs.c - - if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { - return err - } - if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { - return err - } - - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - - earlySecret := hs.earlySecret - if earlySecret == nil { - earlySecret = hs.suite.extract(nil, nil) - } - hs.handshakeSecret = hs.suite.extract(hs.sharedKey, - hs.suite.deriveSecret(earlySecret, "derived", nil)) - - clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, - clientHandshakeTrafficLabel, hs.transcript) - c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret) - serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, - serverHandshakeTrafficLabel, hs.transcript) - c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret) - - if c.quic != nil { - if c.hand.Len() != 0 { - c.sendAlert(alertUnexpectedMessage) - } - c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret) - c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret) - } - - err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil) - if err != nil { - c.sendAlert(alertNoApplicationProtocol) - return err - } - hs.encryptedExtensions.alpnProtocol = selectedProto - c.clientProtocol = selectedProto - - if c.quic != nil { - p, err := c.quicGetTransportParameters() - if err != nil { - return err - } - hs.encryptedExtensions.quicTransportParameters = p - } - - if _, err := hs.c.writeHandshakeRecord(hs.encryptedExtensions, hs.transcript); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) requestClientCert() bool { - return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK -} - -func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { - c := hs.c - - // Only one of PSK and certificates are used at a time. - if hs.usingPSK { - return nil - } - - if hs.requestClientCert() { - // Request a client certificate - certReq := new(certificateRequestMsgTLS13) - certReq.ocspStapling = true - certReq.scts = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() - if c.config.ClientCAs != nil { - certReq.certificateAuthorities = c.config.ClientCAs.Subjects() - } - - if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil { - return err - } - } - - certMsg := new(certificateMsgTLS13) - - certMsg.certificate = *hs.cert - certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 - certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 - - if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { - return err - } - - certVerifyMsg := new(certificateVerifyMsg) - certVerifyMsg.hasSignatureAlgorithm = true - certVerifyMsg.signatureAlgorithm = hs.sigAlg - - sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg) - if err != nil { - return c.sendAlert(alertInternalError) - } - - signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) - if err != nil { - public := hs.cert.PrivateKey.(crypto.Signer).Public() - if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && - rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS - c.sendAlert(alertHandshakeFailure) - } else { - c.sendAlert(alertInternalError) - } - return errors.New("tls: failed to sign handshake: " + err.Error()) - } - certVerifyMsg.signature = sig - - if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) sendServerFinished() error { - c := hs.c - - finished := &finishedMsg{ - verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), - } - - if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { - return err - } - - // Derive secrets that take context through the server Finished. - - hs.masterSecret = hs.suite.extract(nil, - hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil)) - - hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, - clientApplicationTrafficLabel, hs.transcript) - serverSecret := hs.suite.deriveSecret(hs.masterSecret, - serverApplicationTrafficLabel, hs.transcript) - c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret) - - if c.quic != nil { - if c.hand.Len() != 0 { - // TODO: Handle this in setTrafficSecret? - c.sendAlert(alertUnexpectedMessage) - } - c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret) - } - - err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) - - // If we did not request client certificates, at this point we can - // precompute the client finished and roll the transcript forward to send - // session tickets in our first flight. - if !hs.requestClientCert() { - if err := hs.sendSessionTickets(); err != nil { - return err - } - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool { - if hs.c.config.SessionTicketsDisabled { - return false - } - - // QUIC tickets are sent by QUICConn.SendSessionTicket, not automatically. - if hs.c.quic != nil { - return false - } - // Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9. - for _, pskMode := range hs.clientHello.pskModes { - if pskMode == pskModeDHE { - return true - } - } - return false -} - -func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { - c := hs.c - - hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) - finishedMsg := &finishedMsg{ - verifyData: hs.clientFinished, - } - if err := transcriptMsg(finishedMsg, hs.transcript); err != nil { - return err - } - c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, - resumptionLabel, hs.transcript) - - if !hs.shouldSendSessionTickets() { - return nil - } - return c.sendSessionTicket(false) -} - -func (c *Conn) sendSessionTicket(earlyData bool) error { - suite := cipherSuiteTLS13ByID(c.cipherSuite) - if suite == nil { - return errors.New("tls: internal error: unknown cipher suite") - } - - m := new(newSessionTicketMsgTLS13) - - var certsFromClient [][]byte - for _, cert := range c.peerCertificates { - certsFromClient = append(certsFromClient, cert.Raw) - } - state := sessionStateTLS13{ - cipherSuite: suite.id, - createdAt: uint64(c.config.time().Unix()), - resumptionSecret: c.resumptionSecret, - certificate: Certificate{ - Certificate: certsFromClient, - OCSPStaple: c.ocspResponse, - SignedCertificateTimestamps: c.scts, - }, - alpn: c.clientProtocol, - } - if earlyData { - state.maxEarlyData = 0xffffffff - state.appData = c.extraConfig.GetAppDataForSessionTicket() - } - stateBytes, err := state.marshal() - if err != nil { - c.sendAlert(alertInternalError) - return err - } - m.label, err = c.encryptTicket(stateBytes) - if err != nil { - return err - } - m.lifetime = uint32(maxSessionTicketLifetime / time.Second) - - // ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1 - // The value is not stored anywhere; we never need to check the ticket age - // because 0-RTT is not supported. - ageAdd := make([]byte, 4) - _, err = c.config.rand().Read(ageAdd) - if err != nil { - return err - } - - if earlyData { - // RFC 9001, Section 4.6.1 - m.maxEarlyData = 0xffffffff - } - - if _, err := c.writeHandshakeRecord(m, nil); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) readClientCertificate() error { - c := hs.c - - if !hs.requestClientCert() { - // Make sure the connection is still being verified whether or not - // the server requested a client certificate. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - return nil - } - - // If we requested a client certificate, then the client must send a - // certificate message. If it's empty, no CertificateVerify is sent. - - msg, err := c.readHandshake(hs.transcript) - if err != nil { - return err - } - - certMsg, ok := msg.(*certificateMsgTLS13) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - - if err := c.processCertsFromClient(certMsg.certificate); err != nil { - return err - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - if len(certMsg.certificate.Certificate) != 0 { - // certificateVerifyMsg is included in the transcript, but not until - // after we verify the handshake signature, since the state before - // this message was sent is used. - msg, err = c.readHandshake(nil) - if err != nil { - return err - } - - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) - if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, - sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the client certificate: " + err.Error()) - } - - if err := transcriptMsg(certVerify, hs.transcript); err != nil { - return err - } - } - - // If we waited until the client certificates to send session tickets, we - // are ready to do it now. - if err := hs.sendSessionTickets(); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) readClientFinished() error { - c := hs.c - - // finishedMsg is not included in the transcript. - msg, err := c.readHandshake(nil) - if err != nil { - return err - } - - finished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(finished, msg) - } - - if !hmac.Equal(hs.clientFinished, finished.verifyData) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid client finished hash") - } - - c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret) - - return nil -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/key_agreement.go b/vendor/github.com/quic-go/qtls-go1-20/key_agreement.go deleted file mode 100644 index f926869a1..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/key_agreement.go +++ /dev/null @@ -1,366 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto" - "crypto/ecdh" - "crypto/md5" - "crypto/rsa" - "crypto/sha1" - "crypto/x509" - "errors" - "fmt" - "io" -) - -// a keyAgreement implements the client and server side of a TLS key agreement -// protocol by generating and processing key exchange messages. -type keyAgreement interface { - // On the server side, the first two methods are called in order. - - // In the case that the key agreement protocol doesn't use a - // ServerKeyExchange message, generateServerKeyExchange can return nil, - // nil. - generateServerKeyExchange(*config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) - processClientKeyExchange(*config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) - - // On the client side, the next two methods are called in order. - - // This method may not be called if the server doesn't send a - // ServerKeyExchange message. - processServerKeyExchange(*config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error - generateClientKeyExchange(*config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) -} - -var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") -var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") - -// rsaKeyAgreement implements the standard TLS key agreement where the client -// encrypts the pre-master secret to the server's public key. -type rsaKeyAgreement struct{} - -func (ka rsaKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { - return nil, nil -} - -func (ka rsaKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { - if len(ckx.ciphertext) < 2 { - return nil, errClientKeyExchange - } - ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) - if ciphertextLen != len(ckx.ciphertext)-2 { - return nil, errClientKeyExchange - } - ciphertext := ckx.ciphertext[2:] - - priv, ok := cert.PrivateKey.(crypto.Decrypter) - if !ok { - return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") - } - // Perform constant time RSA PKCS #1 v1.5 decryption - preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) - if err != nil { - return nil, err - } - // We don't check the version number in the premaster secret. For one, - // by checking it, we would leak information about the validity of the - // encrypted pre-master secret. Secondly, it provides only a small - // benefit against a downgrade attack and some implementations send the - // wrong version anyway. See the discussion at the end of section - // 7.4.7.1 of RFC 4346. - return preMasterSecret, nil -} - -func (ka rsaKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { - return errors.New("tls: unexpected ServerKeyExchange") -} - -func (ka rsaKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { - preMasterSecret := make([]byte, 48) - preMasterSecret[0] = byte(clientHello.vers >> 8) - preMasterSecret[1] = byte(clientHello.vers) - _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) - if err != nil { - return nil, nil, err - } - - rsaKey, ok := cert.PublicKey.(*rsa.PublicKey) - if !ok { - return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite") - } - encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret) - if err != nil { - return nil, nil, err - } - ckx := new(clientKeyExchangeMsg) - ckx.ciphertext = make([]byte, len(encrypted)+2) - ckx.ciphertext[0] = byte(len(encrypted) >> 8) - ckx.ciphertext[1] = byte(len(encrypted)) - copy(ckx.ciphertext[2:], encrypted) - return preMasterSecret, ckx, nil -} - -// sha1Hash calculates a SHA1 hash over the given byte slices. -func sha1Hash(slices [][]byte) []byte { - hsha1 := sha1.New() - for _, slice := range slices { - hsha1.Write(slice) - } - return hsha1.Sum(nil) -} - -// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the -// concatenation of an MD5 and SHA1 hash. -func md5SHA1Hash(slices [][]byte) []byte { - md5sha1 := make([]byte, md5.Size+sha1.Size) - hmd5 := md5.New() - for _, slice := range slices { - hmd5.Write(slice) - } - copy(md5sha1, hmd5.Sum(nil)) - copy(md5sha1[md5.Size:], sha1Hash(slices)) - return md5sha1 -} - -// hashForServerKeyExchange hashes the given slices and returns their digest -// using the given hash function (for >= TLS 1.2) or using a default based on -// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't -// do pre-hashing, it returns the concatenation of the slices. -func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { - if sigType == signatureEd25519 { - var signed []byte - for _, slice := range slices { - signed = append(signed, slice...) - } - return signed - } - if version >= VersionTLS12 { - h := hashFunc.New() - for _, slice := range slices { - h.Write(slice) - } - digest := h.Sum(nil) - return digest - } - if sigType == signatureECDSA { - return sha1Hash(slices) - } - return md5SHA1Hash(slices) -} - -// ecdheKeyAgreement implements a TLS key agreement where the server -// generates an ephemeral EC public/private key pair and signs it. The -// pre-master secret is then calculated using ECDH. The signature may -// be ECDSA, Ed25519 or RSA. -type ecdheKeyAgreement struct { - version uint16 - isRSA bool - key *ecdh.PrivateKey - - // ckx and preMasterSecret are generated in processServerKeyExchange - // and returned in generateClientKeyExchange. - ckx *clientKeyExchangeMsg - preMasterSecret []byte -} - -func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { - var curveID CurveID - for _, c := range clientHello.supportedCurves { - if config.supportsCurve(c) { - curveID = c - break - } - } - - if curveID == 0 { - return nil, errors.New("tls: no supported elliptic curves offered") - } - if _, ok := curveForCurveID(curveID); !ok { - return nil, errors.New("tls: CurvePreferences includes unsupported curve") - } - - key, err := generateECDHEKey(config.rand(), curveID) - if err != nil { - return nil, err - } - ka.key = key - - // See RFC 4492, Section 5.4. - ecdhePublic := key.PublicKey().Bytes() - serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) - serverECDHEParams[0] = 3 // named curve - serverECDHEParams[1] = byte(curveID >> 8) - serverECDHEParams[2] = byte(curveID) - serverECDHEParams[3] = byte(len(ecdhePublic)) - copy(serverECDHEParams[4:], ecdhePublic) - - priv, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) - } - - var signatureAlgorithm SignatureScheme - var sigType uint8 - var sigHash crypto.Hash - if ka.version >= VersionTLS12 { - signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) - if err != nil { - return nil, err - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return nil, err - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) - if err != nil { - return nil, err - } - } - if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { - return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") - } - - signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) - - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := priv.Sign(config.rand(), signed, signOpts) - if err != nil { - return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) - } - - skx := new(serverKeyExchangeMsg) - sigAndHashLen := 0 - if ka.version >= VersionTLS12 { - sigAndHashLen = 2 - } - skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig)) - copy(skx.key, serverECDHEParams) - k := skx.key[len(serverECDHEParams):] - if ka.version >= VersionTLS12 { - k[0] = byte(signatureAlgorithm >> 8) - k[1] = byte(signatureAlgorithm) - k = k[2:] - } - k[0] = byte(len(sig) >> 8) - k[1] = byte(len(sig)) - copy(k[2:], sig) - - return skx, nil -} - -func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { - if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { - return nil, errClientKeyExchange - } - - peerKey, err := ka.key.Curve().NewPublicKey(ckx.ciphertext[1:]) - if err != nil { - return nil, errClientKeyExchange - } - preMasterSecret, err := ka.key.ECDH(peerKey) - if err != nil { - return nil, errClientKeyExchange - } - - return preMasterSecret, nil -} - -func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { - if len(skx.key) < 4 { - return errServerKeyExchange - } - if skx.key[0] != 3 { // named curve - return errors.New("tls: server selected unsupported curve") - } - curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) - - publicLen := int(skx.key[3]) - if publicLen+4 > len(skx.key) { - return errServerKeyExchange - } - serverECDHEParams := skx.key[:4+publicLen] - publicKey := serverECDHEParams[4:] - - sig := skx.key[4+publicLen:] - if len(sig) < 2 { - return errServerKeyExchange - } - - if _, ok := curveForCurveID(curveID); !ok { - return errors.New("tls: server selected unsupported curve") - } - - key, err := generateECDHEKey(config.rand(), curveID) - if err != nil { - return err - } - ka.key = key - - peerKey, err := key.Curve().NewPublicKey(publicKey) - if err != nil { - return errServerKeyExchange - } - ka.preMasterSecret, err = key.ECDH(peerKey) - if err != nil { - return errServerKeyExchange - } - - ourPublicKey := key.PublicKey().Bytes() - ka.ckx = new(clientKeyExchangeMsg) - ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) - ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) - copy(ka.ckx.ciphertext[1:], ourPublicKey) - - var sigType uint8 - var sigHash crypto.Hash - if ka.version >= VersionTLS12 { - signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) - sig = sig[2:] - if len(sig) < 2 { - return errServerKeyExchange - } - - if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { - return errors.New("tls: certificate used with invalid signature algorithm") - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return err - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) - if err != nil { - return err - } - } - if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { - return errServerKeyExchange - } - - sigLen := int(sig[0])<<8 | int(sig[1]) - if sigLen+2 != len(sig) { - return errServerKeyExchange - } - sig = sig[2:] - - signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams) - if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { - return errors.New("tls: invalid signature by the server certificate: " + err.Error()) - } - return nil -} - -func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { - if ka.ckx == nil { - return nil, nil, errors.New("tls: missing ServerKeyExchange message") - } - - return ka.preMasterSecret, ka.ckx, nil -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/key_schedule.go b/vendor/github.com/quic-go/qtls-go1-20/key_schedule.go deleted file mode 100644 index a4568933d..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/key_schedule.go +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto/ecdh" - "crypto/hmac" - "errors" - "fmt" - "hash" - "io" - - "golang.org/x/crypto/cryptobyte" - "golang.org/x/crypto/hkdf" -) - -// This file contains the functions necessary to compute the TLS 1.3 key -// schedule. See RFC 8446, Section 7. - -const ( - resumptionBinderLabel = "res binder" - clientEarlyTrafficLabel = "c e traffic" - clientHandshakeTrafficLabel = "c hs traffic" - serverHandshakeTrafficLabel = "s hs traffic" - clientApplicationTrafficLabel = "c ap traffic" - serverApplicationTrafficLabel = "s ap traffic" - exporterLabel = "exp master" - resumptionLabel = "res master" - trafficUpdateLabel = "traffic upd" -) - -// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. -func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { - var hkdfLabel cryptobyte.Builder - hkdfLabel.AddUint16(uint16(length)) - hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte("tls13 ")) - b.AddBytes([]byte(label)) - }) - hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(context) - }) - hkdfLabelBytes, err := hkdfLabel.Bytes() - if err != nil { - // Rather than calling BytesOrPanic, we explicitly handle this error, in - // order to provide a reasonable error message. It should be basically - // impossible for this to panic, and routing errors back through the - // tree rooted in this function is quite painful. The labels are fixed - // size, and the context is either a fixed-length computed hash, or - // parsed from a field which has the same length limitation. As such, an - // error here is likely to only be caused during development. - // - // NOTE: another reasonable approach here might be to return a - // randomized slice if we encounter an error, which would break the - // connection, but avoid panicking. This would perhaps be safer but - // significantly more confusing to users. - panic(fmt.Errorf("failed to construct HKDF label: %s", err)) - } - out := make([]byte, length) - n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out) - if err != nil || n != length { - panic("tls: HKDF-Expand-Label invocation failed unexpectedly") - } - return out -} - -// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. -func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { - if transcript == nil { - transcript = c.hash.New() - } - return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) -} - -// extract implements HKDF-Extract with the cipher suite hash. -func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { - if newSecret == nil { - newSecret = make([]byte, c.hash.Size()) - } - return hkdf.Extract(c.hash.New, newSecret, currentSecret) -} - -// nextTrafficSecret generates the next traffic secret, given the current one, -// according to RFC 8446, Section 7.2. -func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { - return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) -} - -// trafficKey generates traffic keys according to RFC 8446, Section 7.3. -func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { - key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) - iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) - return -} - -// finishedHash generates the Finished verify_data or PskBinderEntry according -// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey -// selection. -func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte { - finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size()) - verifyData := hmac.New(c.hash.New, finishedKey) - verifyData.Write(transcript.Sum(nil)) - return verifyData.Sum(nil) -} - -// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to -// RFC 8446, Section 7.5. -func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { - expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) - return func(label string, context []byte, length int) ([]byte, error) { - secret := c.deriveSecret(expMasterSecret, label, nil) - h := c.hash.New() - h.Write(context) - return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil - } -} - -// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman -// according to RFC 8446, Section 4.2.8.2. -func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) { - curve, ok := curveForCurveID(curveID) - if !ok { - return nil, errors.New("tls: internal error: unsupported curve") - } - - return curve.GenerateKey(rand) -} - -func curveForCurveID(id CurveID) (ecdh.Curve, bool) { - switch id { - case X25519: - return ecdh.X25519(), true - case CurveP256: - return ecdh.P256(), true - case CurveP384: - return ecdh.P384(), true - case CurveP521: - return ecdh.P521(), true - default: - return nil, false - } -} - -func curveIDForCurve(curve ecdh.Curve) (CurveID, bool) { - switch curve { - case ecdh.X25519(): - return X25519, true - case ecdh.P256(): - return CurveP256, true - case ecdh.P384(): - return CurveP384, true - case ecdh.P521(): - return CurveP521, true - default: - return 0, false - } -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/notboring.go b/vendor/github.com/quic-go/qtls-go1-20/notboring.go deleted file mode 100644 index f292e4f02..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/notboring.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2022 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -func needFIPS() bool { return false } - -func supportedSignatureAlgorithms() []SignatureScheme { - return defaultSupportedSignatureAlgorithms -} - -func fipsMinVersion(c *config) uint16 { panic("fipsMinVersion") } -func fipsMaxVersion(c *config) uint16 { panic("fipsMaxVersion") } -func fipsCurvePreferences(c *config) []CurveID { panic("fipsCurvePreferences") } -func fipsCipherSuites(c *config) []uint16 { panic("fipsCipherSuites") } - -var fipsSupportedSignatureAlgorithms []SignatureScheme diff --git a/vendor/github.com/quic-go/qtls-go1-20/prf.go b/vendor/github.com/quic-go/qtls-go1-20/prf.go deleted file mode 100644 index 147128918..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/prf.go +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto" - "crypto/hmac" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "errors" - "fmt" - "hash" -) - -// Split a premaster secret in two as specified in RFC 4346, Section 5. -func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { - s1 = secret[0 : (len(secret)+1)/2] - s2 = secret[len(secret)/2:] - return -} - -// pHash implements the P_hash function, as defined in RFC 4346, Section 5. -func pHash(result, secret, seed []byte, hash func() hash.Hash) { - h := hmac.New(hash, secret) - h.Write(seed) - a := h.Sum(nil) - - j := 0 - for j < len(result) { - h.Reset() - h.Write(a) - h.Write(seed) - b := h.Sum(nil) - copy(result[j:], b) - j += len(b) - - h.Reset() - h.Write(a) - a = h.Sum(nil) - } -} - -// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. -func prf10(result, secret, label, seed []byte) { - hashSHA1 := sha1.New - hashMD5 := md5.New - - labelAndSeed := make([]byte, len(label)+len(seed)) - copy(labelAndSeed, label) - copy(labelAndSeed[len(label):], seed) - - s1, s2 := splitPreMasterSecret(secret) - pHash(result, s1, labelAndSeed, hashMD5) - result2 := make([]byte, len(result)) - pHash(result2, s2, labelAndSeed, hashSHA1) - - for i, b := range result2 { - result[i] ^= b - } -} - -// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5. -func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { - return func(result, secret, label, seed []byte) { - labelAndSeed := make([]byte, len(label)+len(seed)) - copy(labelAndSeed, label) - copy(labelAndSeed[len(label):], seed) - - pHash(result, secret, labelAndSeed, hashFunc) - } -} - -const ( - masterSecretLength = 48 // Length of a master secret in TLS 1.1. - finishedVerifyLength = 12 // Length of verify_data in a Finished message. -) - -var masterSecretLabel = []byte("master secret") -var keyExpansionLabel = []byte("key expansion") -var clientFinishedLabel = []byte("client finished") -var serverFinishedLabel = []byte("server finished") - -func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { - switch version { - case VersionTLS10, VersionTLS11: - return prf10, crypto.Hash(0) - case VersionTLS12: - if suite.flags&suiteSHA384 != 0 { - return prf12(sha512.New384), crypto.SHA384 - } - return prf12(sha256.New), crypto.SHA256 - default: - panic("unknown version") - } -} - -func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { - prf, _ := prfAndHashForVersion(version, suite) - return prf -} - -// masterFromPreMasterSecret generates the master secret from the pre-master -// secret. See RFC 5246, Section 8.1. -func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { - seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) - seed = append(seed, clientRandom...) - seed = append(seed, serverRandom...) - - masterSecret := make([]byte, masterSecretLength) - prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed) - return masterSecret -} - -// keysFromMasterSecret generates the connection keys from the master -// secret, given the lengths of the MAC key, cipher key and IV, as defined in -// RFC 2246, Section 6.3. -func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { - seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) - seed = append(seed, serverRandom...) - seed = append(seed, clientRandom...) - - n := 2*macLen + 2*keyLen + 2*ivLen - keyMaterial := make([]byte, n) - prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed) - clientMAC = keyMaterial[:macLen] - keyMaterial = keyMaterial[macLen:] - serverMAC = keyMaterial[:macLen] - keyMaterial = keyMaterial[macLen:] - clientKey = keyMaterial[:keyLen] - keyMaterial = keyMaterial[keyLen:] - serverKey = keyMaterial[:keyLen] - keyMaterial = keyMaterial[keyLen:] - clientIV = keyMaterial[:ivLen] - keyMaterial = keyMaterial[ivLen:] - serverIV = keyMaterial[:ivLen] - return -} - -func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { - var buffer []byte - if version >= VersionTLS12 { - buffer = []byte{} - } - - prf, hash := prfAndHashForVersion(version, cipherSuite) - if hash != 0 { - return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} - } - - return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} -} - -// A finishedHash calculates the hash of a set of handshake messages suitable -// for including in a Finished message. -type finishedHash struct { - client hash.Hash - server hash.Hash - - // Prior to TLS 1.2, an additional MD5 hash is required. - clientMD5 hash.Hash - serverMD5 hash.Hash - - // In TLS 1.2, a full buffer is sadly required. - buffer []byte - - version uint16 - prf func(result, secret, label, seed []byte) -} - -func (h *finishedHash) Write(msg []byte) (n int, err error) { - h.client.Write(msg) - h.server.Write(msg) - - if h.version < VersionTLS12 { - h.clientMD5.Write(msg) - h.serverMD5.Write(msg) - } - - if h.buffer != nil { - h.buffer = append(h.buffer, msg...) - } - - return len(msg), nil -} - -func (h finishedHash) Sum() []byte { - if h.version >= VersionTLS12 { - return h.client.Sum(nil) - } - - out := make([]byte, 0, md5.Size+sha1.Size) - out = h.clientMD5.Sum(out) - return h.client.Sum(out) -} - -// clientSum returns the contents of the verify_data member of a client's -// Finished message. -func (h finishedHash) clientSum(masterSecret []byte) []byte { - out := make([]byte, finishedVerifyLength) - h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) - return out -} - -// serverSum returns the contents of the verify_data member of a server's -// Finished message. -func (h finishedHash) serverSum(masterSecret []byte) []byte { - out := make([]byte, finishedVerifyLength) - h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) - return out -} - -// hashForClientCertificate returns the handshake messages so far, pre-hashed if -// necessary, suitable for signing by a TLS client certificate. -func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash) []byte { - if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { - panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") - } - - if sigType == signatureEd25519 { - return h.buffer - } - - if h.version >= VersionTLS12 { - hash := hashAlg.New() - hash.Write(h.buffer) - return hash.Sum(nil) - } - - if sigType == signatureECDSA { - return h.server.Sum(nil) - } - - return h.Sum() -} - -// discardHandshakeBuffer is called when there is no more need to -// buffer the entirety of the handshake messages. -func (h *finishedHash) discardHandshakeBuffer() { - h.buffer = nil -} - -// noExportedKeyingMaterial is used as a value of -// ConnectionState.ekm when renegotiation is enabled and thus -// we wish to fail all key-material export requests. -func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) { - return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled") -} - -// ekmFromMasterSecret generates exported keying material as defined in RFC 5705. -func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) { - return func(label string, context []byte, length int) ([]byte, error) { - switch label { - case "client finished", "server finished", "master secret", "key expansion": - // These values are reserved and may not be used. - return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label) - } - - seedLen := len(serverRandom) + len(clientRandom) - if context != nil { - seedLen += 2 + len(context) - } - seed := make([]byte, 0, seedLen) - - seed = append(seed, clientRandom...) - seed = append(seed, serverRandom...) - - if context != nil { - if len(context) >= 1<<16 { - return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long") - } - seed = append(seed, byte(len(context)>>8), byte(len(context))) - seed = append(seed, context...) - } - - keyMaterial := make([]byte, length) - prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed) - return keyMaterial, nil - } -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/quic.go b/vendor/github.com/quic-go/qtls-go1-20/quic.go deleted file mode 100644 index 6be1e9196..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/quic.go +++ /dev/null @@ -1,412 +0,0 @@ -// Copyright 2023 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "context" - "errors" - "fmt" -) - -// QUICEncryptionLevel represents a QUIC encryption level used to transmit -// handshake messages. -type QUICEncryptionLevel int - -const ( - QUICEncryptionLevelInitial = QUICEncryptionLevel(iota) - QUICEncryptionLevelEarly - QUICEncryptionLevelHandshake - QUICEncryptionLevelApplication -) - -func (l QUICEncryptionLevel) String() string { - switch l { - case QUICEncryptionLevelInitial: - return "Initial" - case QUICEncryptionLevelEarly: - return "Early" - case QUICEncryptionLevelHandshake: - return "Handshake" - case QUICEncryptionLevelApplication: - return "Application" - default: - return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l)) - } -} - -// A QUICConn represents a connection which uses a QUIC implementation as the underlying -// transport as described in RFC 9001. -// -// Methods of QUICConn are not safe for concurrent use. -type QUICConn struct { - conn *Conn - - sessionTicketSent bool -} - -// A QUICConfig configures a QUICConn. -type QUICConfig struct { - TLSConfig *Config - ExtraConfig *ExtraConfig -} - -// A QUICEventKind is a type of operation on a QUIC connection. -type QUICEventKind int - -const ( - // QUICNoEvent indicates that there are no events available. - QUICNoEvent QUICEventKind = iota - - // QUICSetReadSecret and QUICSetWriteSecret provide the read and write - // secrets for a given encryption level. - // QUICEvent.Level, QUICEvent.Data, and QUICEvent.Suite are set. - // - // Secrets for the Initial encryption level are derived from the initial - // destination connection ID, and are not provided by the QUICConn. - QUICSetReadSecret - QUICSetWriteSecret - - // QUICWriteData provides data to send to the peer in CRYPTO frames. - // QUICEvent.Data is set. - QUICWriteData - - // QUICTransportParameters provides the peer's QUIC transport parameters. - // QUICEvent.Data is set. - QUICTransportParameters - - // QUICTransportParametersRequired indicates that the caller must provide - // QUIC transport parameters to send to the peer. The caller should set - // the transport parameters with QUICConn.SetTransportParameters and call - // QUICConn.NextEvent again. - // - // If transport parameters are set before calling QUICConn.Start, the - // connection will never generate a QUICTransportParametersRequired event. - QUICTransportParametersRequired - - // QUICRejectedEarlyData indicates that the server rejected 0-RTT data even - // if we offered it. It's returned before QUICEncryptionLevelApplication - // keys are returned. - QUICRejectedEarlyData - - // QUICHandshakeDone indicates that the TLS handshake has completed. - QUICHandshakeDone -) - -// A QUICEvent is an event occurring on a QUIC connection. -// -// The type of event is specified by the Kind field. -// The contents of the other fields are kind-specific. -type QUICEvent struct { - Kind QUICEventKind - - // Set for QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData. - Level QUICEncryptionLevel - - // Set for QUICTransportParameters, QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData. - // The contents are owned by crypto/tls, and are valid until the next NextEvent call. - Data []byte - - // Set for QUICSetReadSecret and QUICSetWriteSecret. - Suite uint16 -} - -type quicState struct { - events []QUICEvent - nextEvent int - - // eventArr is a statically allocated event array, large enough to handle - // the usual maximum number of events resulting from a single call: transport - // parameters, Initial data, Early read secret, Handshake write and read - // secrets, Handshake data, Application write secret, Application data. - eventArr [8]QUICEvent - - started bool - signalc chan struct{} // handshake data is available to be read - blockedc chan struct{} // handshake is waiting for data, closed when done - cancelc <-chan struct{} // handshake has been canceled - cancel context.CancelFunc - - // readbuf is shared between HandleData and the handshake goroutine. - // HandshakeCryptoData passes ownership to the handshake goroutine by - // reading from signalc, and reclaims ownership by reading from blockedc. - readbuf []byte - - transportParams []byte // to send to the peer -} - -// QUICClient returns a new TLS client side connection using QUICTransport as the -// underlying transport. The config cannot be nil. -// -// The config's MinVersion must be at least TLS 1.3. -func QUICClient(config *QUICConfig) *QUICConn { - return newQUICConn(Client(nil, config.TLSConfig), config.ExtraConfig) -} - -// QUICServer returns a new TLS server side connection using QUICTransport as the -// underlying transport. The config cannot be nil. -// -// The config's MinVersion must be at least TLS 1.3. -func QUICServer(config *QUICConfig) *QUICConn { - return newQUICConn(Server(nil, config.TLSConfig), config.ExtraConfig) -} - -func newQUICConn(conn *Conn, extraConfig *ExtraConfig) *QUICConn { - conn.quic = &quicState{ - signalc: make(chan struct{}), - blockedc: make(chan struct{}), - } - conn.quic.events = conn.quic.eventArr[:0] - conn.extraConfig = extraConfig - return &QUICConn{ - conn: conn, - } -} - -// Start starts the client or server handshake protocol. -// It may produce connection events, which may be read with NextEvent. -// -// Start must be called at most once. -func (q *QUICConn) Start(ctx context.Context) error { - if q.conn.quic.started { - return quicError(errors.New("tls: Start called more than once")) - } - q.conn.quic.started = true - if q.conn.config.MinVersion < VersionTLS13 { - return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.13")) - } - go q.conn.HandshakeContext(ctx) - if _, ok := <-q.conn.quic.blockedc; !ok { - return q.conn.handshakeErr - } - return nil -} - -// NextEvent returns the next event occurring on the connection. -// It returns an event with a Kind of QUICNoEvent when no events are available. -func (q *QUICConn) NextEvent() QUICEvent { - qs := q.conn.quic - if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 { - // Write over some of the previous event's data, - // to catch callers erroniously retaining it. - qs.events[last].Data[0] = 0 - } - if qs.nextEvent >= len(qs.events) { - qs.events = qs.events[:0] - qs.nextEvent = 0 - return QUICEvent{Kind: QUICNoEvent} - } - e := qs.events[qs.nextEvent] - qs.events[qs.nextEvent] = QUICEvent{} // zero out references to data - qs.nextEvent++ - return e -} - -// Close closes the connection and stops any in-progress handshake. -func (q *QUICConn) Close() error { - if q.conn.quic.cancel == nil { - return nil // never started - } - q.conn.quic.cancel() - for range q.conn.quic.blockedc { - // Wait for the handshake goroutine to return. - } - return q.conn.handshakeErr -} - -// HandleData handles handshake bytes received from the peer. -// It may produce connection events, which may be read with NextEvent. -func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error { - c := q.conn - if c.in.level != level { - return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level"))) - } - c.quic.readbuf = data - <-c.quic.signalc - _, ok := <-c.quic.blockedc - if ok { - // The handshake goroutine is waiting for more data. - return nil - } - // The handshake goroutine has exited. - c.hand.Write(c.quic.readbuf) - c.quic.readbuf = nil - for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil { - b := q.conn.hand.Bytes() - n := int(b[1])<<16 | int(b[2])<<8 | int(b[3]) - if 4+n < len(b) { - return nil - } - if err := q.conn.handlePostHandshakeMessage(); err != nil { - return quicError(err) - } - } - if q.conn.handshakeErr != nil { - return quicError(q.conn.handshakeErr) - } - return nil -} - -// SendSessionTicket sends a session ticket to the client. -// It produces connection events, which may be read with NextEvent. -// Currently, it can only be called once. -func (q *QUICConn) SendSessionTicket(earlyData bool) error { - c := q.conn - if !c.isHandshakeComplete.Load() { - return quicError(errors.New("tls: SendSessionTicket called before handshake completed")) - } - if c.isClient { - return quicError(errors.New("tls: SendSessionTicket called on the client")) - } - if q.sessionTicketSent { - return quicError(errors.New("tls: SendSessionTicket called multiple times")) - } - q.sessionTicketSent = true - return quicError(c.sendSessionTicket(earlyData)) -} - -// ConnectionState returns basic TLS details about the connection. -func (q *QUICConn) ConnectionState() ConnectionState { - return q.conn.ConnectionState() -} - -// SetTransportParameters sets the transport parameters to send to the peer. -// -// Server connections may delay setting the transport parameters until after -// receiving the client's transport parameters. See QUICTransportParametersRequired. -func (q *QUICConn) SetTransportParameters(params []byte) { - if params == nil { - params = []byte{} - } - q.conn.quic.transportParams = params - if q.conn.quic.started { - <-q.conn.quic.signalc - <-q.conn.quic.blockedc - } -} - -// quicError ensures err is an AlertError. -// If err is not already, quicError wraps it with alertInternalError. -func quicError(err error) error { - if err == nil { - return nil - } - var ae AlertError - if errors.As(err, &ae) { - return err - } - var a alert - if !errors.As(err, &a) { - a = alertInternalError - } - // Return an error wrapping the original error and an AlertError. - // Truncate the text of the alert to 0 characters. - return fmt.Errorf("%w%.0w", err, AlertError(a)) -} - -func (c *Conn) quicReadHandshakeBytes(n int) error { - for c.hand.Len() < n { - if err := c.quicWaitForSignal(); err != nil { - return err - } - } - return nil -} - -func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { - c.quic.events = append(c.quic.events, QUICEvent{ - Kind: QUICSetReadSecret, - Level: level, - Suite: suite, - Data: secret, - }) -} - -func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { - c.quic.events = append(c.quic.events, QUICEvent{ - Kind: QUICSetWriteSecret, - Level: level, - Suite: suite, - Data: secret, - }) -} - -func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) { - var last *QUICEvent - if len(c.quic.events) > 0 { - last = &c.quic.events[len(c.quic.events)-1] - } - if last == nil || last.Kind != QUICWriteData || last.Level != level { - c.quic.events = append(c.quic.events, QUICEvent{ - Kind: QUICWriteData, - Level: level, - }) - last = &c.quic.events[len(c.quic.events)-1] - } - last.Data = append(last.Data, data...) -} - -func (c *Conn) quicSetTransportParameters(params []byte) { - c.quic.events = append(c.quic.events, QUICEvent{ - Kind: QUICTransportParameters, - Data: params, - }) -} - -func (c *Conn) quicGetTransportParameters() ([]byte, error) { - if c.quic.transportParams == nil { - c.quic.events = append(c.quic.events, QUICEvent{ - Kind: QUICTransportParametersRequired, - }) - } - for c.quic.transportParams == nil { - if err := c.quicWaitForSignal(); err != nil { - return nil, err - } - } - return c.quic.transportParams, nil -} - -func (c *Conn) quicHandshakeComplete() { - c.quic.events = append(c.quic.events, QUICEvent{ - Kind: QUICHandshakeDone, - }) -} - -func (c *Conn) quicRejectedEarlyData() { - c.quic.events = append(c.quic.events, QUICEvent{ - Kind: QUICRejectedEarlyData, - }) -} - -// quicWaitForSignal notifies the QUICConn that handshake progress is blocked, -// and waits for a signal that the handshake should proceed. -// -// The handshake may become blocked waiting for handshake bytes -// or for the user to provide transport parameters. -func (c *Conn) quicWaitForSignal() error { - // Drop the handshake mutex while blocked to allow the user - // to call ConnectionState before the handshake completes. - c.handshakeMutex.Unlock() - defer c.handshakeMutex.Lock() - // Send on blockedc to notify the QUICConn that the handshake is blocked. - // Exported methods of QUICConn wait for the handshake to become blocked - // before returning to the user. - select { - case c.quic.blockedc <- struct{}{}: - case <-c.quic.cancelc: - return c.sendAlertLocked(alertCloseNotify) - } - // The QUICConn reads from signalc to notify us that the handshake may - // be able to proceed. (The QUICConn reads, because we close signalc to - // indicate that the handshake has completed.) - select { - case c.quic.signalc <- struct{}{}: - c.hand.Write(c.quic.readbuf) - c.quic.readbuf = nil - case <-c.quic.cancelc: - return c.sendAlertLocked(alertCloseNotify) - } - return nil -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/ticket.go b/vendor/github.com/quic-go/qtls-go1-20/ticket.go deleted file mode 100644 index 366620707..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/ticket.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/sha256" - "crypto/subtle" - "errors" - "golang.org/x/crypto/cryptobyte" - "io" -) - -// sessionState contains the information that is serialized into a session -// ticket in order to later resume a connection. -type sessionState struct { - vers uint16 - cipherSuite uint16 - createdAt uint64 - masterSecret []byte // opaque master_secret<1..2^16-1>; - // struct { opaque certificate<1..2^24-1> } Certificate; - certificates [][]byte // Certificate certificate_list<0..2^24-1>; - - // usedOldKey is true if the ticket from which this session came from - // was encrypted with an older key and thus should be refreshed. - usedOldKey bool -} - -func (m *sessionState) marshal() ([]byte, error) { - var b cryptobyte.Builder - b.AddUint16(m.vers) - b.AddUint16(m.cipherSuite) - addUint64(&b, m.createdAt) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.masterSecret) - }) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - for _, cert := range m.certificates { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(cert) - }) - } - }) - return b.Bytes() -} - -func (m *sessionState) unmarshal(data []byte) bool { - *m = sessionState{usedOldKey: m.usedOldKey} - s := cryptobyte.String(data) - if ok := s.ReadUint16(&m.vers) && - s.ReadUint16(&m.cipherSuite) && - readUint64(&s, &m.createdAt) && - readUint16LengthPrefixed(&s, &m.masterSecret) && - len(m.masterSecret) != 0; !ok { - return false - } - var certList cryptobyte.String - if !s.ReadUint24LengthPrefixed(&certList) { - return false - } - for !certList.Empty() { - var cert []byte - if !readUint24LengthPrefixed(&certList, &cert) { - return false - } - m.certificates = append(m.certificates, cert) - } - return s.Empty() -} - -// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first -// version (revision = 0) doesn't carry any of the information needed for 0-RTT -// validation and the nonce is always empty. -// version (revision = 1) carries the max_early_data_size sent in the ticket. -// version (revision = 2) carries the ALPN sent in the ticket. -type sessionStateTLS13 struct { - // uint8 version = 0x0304; - // uint8 revision = 2; - cipherSuite uint16 - createdAt uint64 - resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; - certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; - maxEarlyData uint32 - alpn string - - appData []byte -} - -func (m *sessionStateTLS13) marshal() ([]byte, error) { - var b cryptobyte.Builder - b.AddUint16(VersionTLS13) - b.AddUint8(2) // revision - b.AddUint16(m.cipherSuite) - addUint64(&b, m.createdAt) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.resumptionSecret) - }) - marshalCertificate(&b, m.certificate) - b.AddUint32(m.maxEarlyData) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpn)) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.appData) - }) - return b.Bytes() -} - -func (m *sessionStateTLS13) unmarshal(data []byte) bool { - *m = sessionStateTLS13{} - s := cryptobyte.String(data) - var version uint16 - var revision uint8 - var alpn []byte - ret := s.ReadUint16(&version) && - version == VersionTLS13 && - s.ReadUint8(&revision) && - revision == 2 && - s.ReadUint16(&m.cipherSuite) && - readUint64(&s, &m.createdAt) && - readUint8LengthPrefixed(&s, &m.resumptionSecret) && - len(m.resumptionSecret) != 0 && - unmarshalCertificate(&s, &m.certificate) && - s.ReadUint32(&m.maxEarlyData) && - readUint8LengthPrefixed(&s, &alpn) && - readUint16LengthPrefixed(&s, &m.appData) && - s.Empty() - m.alpn = string(alpn) - return ret -} - -func (c *Conn) encryptTicket(state []byte) ([]byte, error) { - if len(c.ticketKeys) == 0 { - return nil, errors.New("tls: internal error: session ticket keys unavailable") - } - - encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) - keyName := encrypted[:ticketKeyNameLen] - iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] - macBytes := encrypted[len(encrypted)-sha256.Size:] - - if _, err := io.ReadFull(c.config.rand(), iv); err != nil { - return nil, err - } - key := c.ticketKeys[0] - copy(keyName, key.keyName[:]) - block, err := aes.NewCipher(key.aesKey[:]) - if err != nil { - return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) - } - cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) - - mac := hmac.New(sha256.New, key.hmacKey[:]) - mac.Write(encrypted[:len(encrypted)-sha256.Size]) - mac.Sum(macBytes[:0]) - - return encrypted, nil -} - -func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) { - if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { - return nil, false - } - - keyName := encrypted[:ticketKeyNameLen] - iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] - macBytes := encrypted[len(encrypted)-sha256.Size:] - ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] - - keyIndex := -1 - for i, candidateKey := range c.ticketKeys { - if bytes.Equal(keyName, candidateKey.keyName[:]) { - keyIndex = i - break - } - } - if keyIndex == -1 { - return nil, false - } - key := &c.ticketKeys[keyIndex] - - mac := hmac.New(sha256.New, key.hmacKey[:]) - mac.Write(encrypted[:len(encrypted)-sha256.Size]) - expected := mac.Sum(nil) - - if subtle.ConstantTimeCompare(macBytes, expected) != 1 { - return nil, false - } - - block, err := aes.NewCipher(key.aesKey[:]) - if err != nil { - return nil, false - } - plaintext = make([]byte, len(ciphertext)) - cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) - - return plaintext, keyIndex > 0 -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/tls.go b/vendor/github.com/quic-go/qtls-go1-20/tls.go deleted file mode 100644 index 47eed0855..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/tls.go +++ /dev/null @@ -1,356 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// package qtls partially implements TLS 1.2, as specified in RFC 5246, -// and TLS 1.3, as specified in RFC 8446. -package qtls - -// BUG(agl): The crypto/tls package only implements some countermeasures -// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 -// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and -// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. - -import ( - "bytes" - "context" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "fmt" - "net" - "os" - "strings" -) - -// Server returns a new TLS server side connection -// using conn as the underlying transport. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func Server(conn net.Conn, config *Config) *Conn { - c := &Conn{ - conn: conn, - config: fromConfig(config), - } - c.handshakeFn = c.serverHandshake - return c -} - -// Client returns a new TLS client side connection -// using conn as the underlying transport. -// The config cannot be nil: users must set either ServerName or -// InsecureSkipVerify in the config. -func Client(conn net.Conn, config *Config) *Conn { - c := &Conn{ - conn: conn, - config: fromConfig(config), - isClient: true, - } - c.handshakeFn = c.clientHandshake - return c -} - -// A listener implements a network listener (net.Listener) for TLS connections. -type listener struct { - net.Listener - config *Config -} - -// Accept waits for and returns the next incoming TLS connection. -// The returned connection is of type *Conn. -func (l *listener) Accept() (net.Conn, error) { - c, err := l.Listener.Accept() - if err != nil { - return nil, err - } - return Server(c, l.config), nil -} - -// NewListener creates a Listener which accepts connections from an inner -// Listener and wraps each connection with Server. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func NewListener(inner net.Listener, config *Config) net.Listener { - l := new(listener) - l.Listener = inner - l.config = config - return l -} - -// Listen creates a TLS listener accepting connections on the -// given network address using net.Listen. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func Listen(network, laddr string, config *Config) (net.Listener, error) { - if config == nil || len(config.Certificates) == 0 && - config.GetCertificate == nil && config.GetConfigForClient == nil { - return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config") - } - l, err := net.Listen(network, laddr) - if err != nil { - return nil, err - } - return NewListener(l, config), nil -} - -type timeoutError struct{} - -func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } -func (timeoutError) Timeout() bool { return true } -func (timeoutError) Temporary() bool { return true } - -// DialWithDialer connects to the given network address using dialer.Dial and -// then initiates a TLS handshake, returning the resulting TLS connection. Any -// timeout or deadline given in the dialer apply to connection and TLS -// handshake as a whole. -// -// DialWithDialer interprets a nil configuration as equivalent to the zero -// configuration; see the documentation of Config for the defaults. -// -// DialWithDialer uses context.Background internally; to specify the context, -// use Dialer.DialContext with NetDialer set to the desired dialer. -func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { - return dial(context.Background(), dialer, network, addr, config) -} - -func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { - if netDialer.Timeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) - defer cancel() - } - - if !netDialer.Deadline.IsZero() { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) - defer cancel() - } - - rawConn, err := netDialer.DialContext(ctx, network, addr) - if err != nil { - return nil, err - } - - colonPos := strings.LastIndex(addr, ":") - if colonPos == -1 { - colonPos = len(addr) - } - hostname := addr[:colonPos] - - if config == nil { - config = defaultConfig() - } - // If no ServerName is set, infer the ServerName - // from the hostname we're connecting to. - if config.ServerName == "" { - // Make a copy to avoid polluting argument or default. - c := config.Clone() - c.ServerName = hostname - config = c - } - - conn := Client(rawConn, config) - if err := conn.HandshakeContext(ctx); err != nil { - rawConn.Close() - return nil, err - } - return conn, nil -} - -// Dial connects to the given network address using net.Dial -// and then initiates a TLS handshake, returning the resulting -// TLS connection. -// Dial interprets a nil configuration as equivalent to -// the zero configuration; see the documentation of Config -// for the defaults. -func Dial(network, addr string, config *Config) (*Conn, error) { - return DialWithDialer(new(net.Dialer), network, addr, config) -} - -// Dialer dials TLS connections given a configuration and a Dialer for the -// underlying connection. -type Dialer struct { - // NetDialer is the optional dialer to use for the TLS connections' - // underlying TCP connections. - // A nil NetDialer is equivalent to the net.Dialer zero value. - NetDialer *net.Dialer - - // Config is the TLS configuration to use for new connections. - // A nil configuration is equivalent to the zero - // configuration; see the documentation of Config for the - // defaults. - Config *Config -} - -// Dial connects to the given network address and initiates a TLS -// handshake, returning the resulting TLS connection. -// -// The returned Conn, if any, will always be of type *Conn. -// -// Dial uses context.Background internally; to specify the context, -// use DialContext. -func (d *Dialer) Dial(network, addr string) (net.Conn, error) { - return d.DialContext(context.Background(), network, addr) -} - -func (d *Dialer) netDialer() *net.Dialer { - if d.NetDialer != nil { - return d.NetDialer - } - return new(net.Dialer) -} - -// DialContext connects to the given network address and initiates a TLS -// handshake, returning the resulting TLS connection. -// -// The provided Context must be non-nil. If the context expires before -// the connection is complete, an error is returned. Once successfully -// connected, any expiration of the context will not affect the -// connection. -// -// The returned Conn, if any, will always be of type *Conn. -func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - c, err := dial(ctx, d.netDialer(), network, addr, d.Config) - if err != nil { - // Don't return c (a typed nil) in an interface. - return nil, err - } - return c, nil -} - -// LoadX509KeyPair reads and parses a public/private key pair from a pair -// of files. The files must contain PEM encoded data. The certificate file -// may contain intermediate certificates following the leaf certificate to -// form a certificate chain. On successful return, Certificate.Leaf will -// be nil because the parsed form of the certificate is not retained. -func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { - certPEMBlock, err := os.ReadFile(certFile) - if err != nil { - return Certificate{}, err - } - keyPEMBlock, err := os.ReadFile(keyFile) - if err != nil { - return Certificate{}, err - } - return X509KeyPair(certPEMBlock, keyPEMBlock) -} - -// X509KeyPair parses a public/private key pair from a pair of -// PEM encoded data. On successful return, Certificate.Leaf will be nil because -// the parsed form of the certificate is not retained. -func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { - fail := func(err error) (Certificate, error) { return Certificate{}, err } - - var cert Certificate - var skippedBlockTypes []string - for { - var certDERBlock *pem.Block - certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) - if certDERBlock == nil { - break - } - if certDERBlock.Type == "CERTIFICATE" { - cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) - } else { - skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) - } - } - - if len(cert.Certificate) == 0 { - if len(skippedBlockTypes) == 0 { - return fail(errors.New("tls: failed to find any PEM data in certificate input")) - } - if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { - return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) - } - return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) - } - - skippedBlockTypes = skippedBlockTypes[:0] - var keyDERBlock *pem.Block - for { - keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) - if keyDERBlock == nil { - if len(skippedBlockTypes) == 0 { - return fail(errors.New("tls: failed to find any PEM data in key input")) - } - if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { - return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) - } - return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) - } - if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { - break - } - skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) - } - - // We don't need to parse the public key for TLS, but we so do anyway - // to check that it looks sane and matches the private key. - x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) - if err != nil { - return fail(err) - } - - cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) - if err != nil { - return fail(err) - } - - switch pub := x509Cert.PublicKey.(type) { - case *rsa.PublicKey: - priv, ok := cert.PrivateKey.(*rsa.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if pub.N.Cmp(priv.N) != 0 { - return fail(errors.New("tls: private key does not match public key")) - } - case *ecdsa.PublicKey: - priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { - return fail(errors.New("tls: private key does not match public key")) - } - case ed25519.PublicKey: - priv, ok := cert.PrivateKey.(ed25519.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { - return fail(errors.New("tls: private key does not match public key")) - } - default: - return fail(errors.New("tls: unknown public key algorithm")) - } - - return cert, nil -} - -// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates -// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. -// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. -func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { - if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { - return key, nil - } - if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { - switch key := key.(type) { - case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: - return key, nil - default: - return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") - } - } - if key, err := x509.ParseECPrivateKey(der); err == nil { - return key, nil - } - - return nil, errors.New("tls: failed to parse private key") -} diff --git a/vendor/github.com/quic-go/qtls-go1-20/unsafe.go b/vendor/github.com/quic-go/qtls-go1-20/unsafe.go deleted file mode 100644 index 67a75677c..000000000 --- a/vendor/github.com/quic-go/qtls-go1-20/unsafe.go +++ /dev/null @@ -1,101 +0,0 @@ -package qtls - -import ( - "crypto/tls" - "reflect" - "unsafe" -) - -func init() { - if !structsEqual(&tls.ConnectionState{}, &connectionState{}) { - panic("qtls.ConnectionState doesn't match") - } - if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) { - panic("qtls.ClientSessionState doesn't match") - } - if !structsEqual(&tls.CertificateRequestInfo{}, &certificateRequestInfo{}) { - panic("qtls.CertificateRequestInfo doesn't match") - } - if !structsEqual(&tls.Config{}, &config{}) { - panic("qtls.Config doesn't match") - } - if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) { - panic("qtls.ClientHelloInfo doesn't match") - } -} - -func toConnectionState(c connectionState) ConnectionState { - return *(*ConnectionState)(unsafe.Pointer(&c)) -} - -func toClientSessionState(s *clientSessionState) *ClientSessionState { - return (*ClientSessionState)(unsafe.Pointer(s)) -} - -func fromClientSessionState(s *ClientSessionState) *clientSessionState { - return (*clientSessionState)(unsafe.Pointer(s)) -} - -func toCertificateRequestInfo(i *certificateRequestInfo) *CertificateRequestInfo { - return (*CertificateRequestInfo)(unsafe.Pointer(i)) -} - -func toConfig(c *config) *Config { - return (*Config)(unsafe.Pointer(c)) -} - -func fromConfig(c *Config) *config { - return (*config)(unsafe.Pointer(c)) -} - -func toClientHelloInfo(chi *clientHelloInfo) *ClientHelloInfo { - return (*ClientHelloInfo)(unsafe.Pointer(chi)) -} - -func structsEqual(a, b interface{}) bool { - return compare(reflect.ValueOf(a), reflect.ValueOf(b)) -} - -func compare(a, b reflect.Value) bool { - sa := a.Elem() - sb := b.Elem() - if sa.NumField() != sb.NumField() { - return false - } - for i := 0; i < sa.NumField(); i++ { - fa := sa.Type().Field(i) - fb := sb.Type().Field(i) - if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { - if fa.Type.Kind() != fb.Type.Kind() { - return false - } - if fa.Type.Kind() == reflect.Slice { - if !compareStruct(fa.Type.Elem(), fb.Type.Elem()) { - return false - } - continue - } - return false - } - } - return true -} - -func compareStruct(a, b reflect.Type) bool { - if a.NumField() != b.NumField() { - return false - } - for i := 0; i < a.NumField(); i++ { - fa := a.Field(i) - fb := b.Field(i) - if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { - return false - } - } - return true -} - -// InitSessionTicketKeys triggers the initialization of session ticket keys. -func InitSessionTicketKeys(conf *Config) { - fromConfig(conf).ticketKeys(nil) -} diff --git a/vendor/github.com/quic-go/quic-go/.golangci.yml b/vendor/github.com/quic-go/quic-go/.golangci.yml index 1315759bc..469d54cfb 100644 --- a/vendor/github.com/quic-go/quic-go/.golangci.yml +++ b/vendor/github.com/quic-go/quic-go/.golangci.yml @@ -2,16 +2,6 @@ run: skip-files: - internal/handshake/cipher_suite.go linters-settings: - depguard: - type: blacklist - packages: - - github.com/marten-seemann/qtls - - github.com/quic-go/qtls-go1-19 - - github.com/quic-go/qtls-go1-20 - packages-with-error-message: - - github.com/marten-seemann/qtls: "importing qtls only allowed in internal/qtls" - - github.com/quic-go/qtls-go1-19: "importing qtls only allowed in internal/qtls" - - github.com/quic-go/qtls-go1-20: "importing qtls only allowed in internal/qtls" misspell: ignore-words: - ect @@ -20,7 +10,6 @@ linters: disable-all: true enable: - asciicheck - - depguard - exhaustive - exportloopref - goimports diff --git a/vendor/github.com/quic-go/quic-go/README.md b/vendor/github.com/quic-go/quic-go/README.md index a793f9847..faba82f3d 100644 --- a/vendor/github.com/quic-go/quic-go/README.md +++ b/vendor/github.com/quic-go/quic-go/README.md @@ -12,6 +12,9 @@ In addition to these base RFCs, it also implements the following RFCs: * Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) * Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)) * QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369)) +* QUIC Event Logging using qlog ([draft-ietf-quic-qlog-main-schema](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-main-schema/) and [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/)) + +Support for WebTransport over HTTP/3 ([draft-ietf-webtrans-http3](https://datatracker.ietf.org/doc/draft-ietf-webtrans-http3/)) is implemented in [webtransport-go](https://github.com/quic-go/webtransport-go). ## Using QUIC @@ -121,7 +124,7 @@ In case the application wishes to abort sending on a `quic.SendStream` or a `qui Conversely, in case the application wishes to abort receiving from a `quic.ReceiveStream` or a `quic.Stream`, it can ask the sender to abort data transmission by calling `CancelRead` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Writer`. Note that for bidirectional streams, `CancelWrite` _only_ resets the receive side of the stream. It is still possible to write to the stream. -A bidirectional stream is only closed once both the read and the write side of the stream have been either closed and reset. Only then the peer is granted a new stream according to the maximum number of concurrent streams configured via `quic.Config.MaxIncomingStreams`. +A bidirectional stream is only closed once both the read and the write side of the stream have been either closed or reset. Only then the peer is granted a new stream according to the maximum number of concurrent streams configured via `quic.Config.MaxIncomingStreams`. ### Configuring QUIC @@ -158,23 +161,41 @@ On the receiver side, this is surfaced as a `quic.ApplicationError`. Unreliable datagrams are a QUIC extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) that is negotiated during the handshake. Support can be enabled by setting the `quic.Config.EnableDatagram` flag. Note that this doesn't guarantee that the peer also supports datagrams. Whether or not the feature negotiation succeeded can be learned from the `quic.ConnectionState.SupportsDatagrams` obtained from `quic.Connection.ConnectionState()`. -QUIC DATAGRAMs are a new QUIC frame type sent in QUIC 1-RTT packets (i.e. after completion of the handshake). Therefore, they're end-to-end encrypted and congestion-controlled. However, if a DATAGRAM frame is deemed lost by QUIC's loss detection mechanism, they are not automatically retransmitted. +QUIC DATAGRAMs are a new QUIC frame type sent in QUIC 1-RTT packets (i.e. after completion of the handshake). Therefore, they're end-to-end encrypted and congestion-controlled. However, if a DATAGRAM frame is deemed lost by QUIC's loss detection mechanism, they are not retransmitted. -Datagrams are sent using the `SendMessage` method on the `quic.Connection`: +Datagrams are sent using the `SendDatagram` method on the `quic.Connection`: ```go -conn.SendMessage([]byte("foobar")) +conn.SendDatagram([]byte("foobar")) ``` -And received using `ReceiveMessage`: +And received using `ReceiveDatagram`: ```go -msg, err := conn.ReceiveMessage() +msg, err := conn.ReceiveDatagram() ``` Note that this code path is currently not optimized. It works for datagrams that are sent occasionally, but it doesn't achieve the same throughput as writing data on a stream. Please get in touch on issue #3766 if your use case relies on high datagram throughput, or if you'd like to help fix this issue. There are also some restrictions regarding the maximum message size (see #3599). +### QUIC Event Logging using qlog + +quic-go logs a wide range of events defined in [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/), providing comprehensive insights in the internals of a QUIC connection. + +qlog files can be processed by a number of 3rd-party tools. [qviz](https://qvis.quictools.info/) has proven very useful for debugging all kinds of QUIC connection failures. +qlog can be activated by setting the `Tracer` callback on the `Config`. It is called as soon as quic-go decides to start the QUIC handshake on a new connection. +`qlog.DefaultTracer` provides a tracer implementation which writes qlog files to a directory specified by the `QLOGDIR` environment variable, if set. +The default qlog tracer can be used like this: +```go +quic.Config{ + Tracer: qlog.DefaultTracer, +} +``` + +This example creates a new qlog file under `/_.qlog`, e.g. `qlogs/2e0407da_client.qlog`. + + +For custom qlog behavior, `qlog.NewConnectionTracer` can be used. ## Using HTTP/3 @@ -200,14 +221,18 @@ http.Client{ ## Projects using quic-go | Project | Description | Stars | -|-----------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------| +| ---------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- | | [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) | | [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | | [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | | [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | -| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) | +| [frp](https://github.com/fatedier/frp) | A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet | ![GitHub Repo stars](https://img.shields.io/github/stars/fatedier/frp?style=flat-square) | +| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) | +| [gost](https://github.com/go-gost/gost) | A simple security tunnel written in Go | ![GitHub Repo stars](https://img.shields.io/github/stars/go-gost/gost?style=flat-square) | +| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) | | [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) | | [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | +| [RoadRunner](https://github.com/roadrunner-server/roadrunner) | High-performance PHP application server, process manager written in Go and powered with plugins | ![GitHub Repo stars](https://img.shields.io/github/stars/roadrunner-server/roadrunner?style=flat-square) | | [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | | [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) | | [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) | @@ -219,11 +244,6 @@ If you'd like to see your project added to this list, please send us a PR. quic-go always aims to support the latest two Go releases. -### Dependency on forked crypto/tls - -Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20). -This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward. - ## Contributing We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment. diff --git a/vendor/github.com/quic-go/quic-go/client.go b/vendor/github.com/quic-go/quic-go/client.go index 61cd75265..70dd5e19c 100644 --- a/vendor/github.com/quic-go/quic-go/client.go +++ b/vendor/github.com/quic-go/quic-go/client.go @@ -28,13 +28,13 @@ type client struct { initialPacketNumber protocol.PacketNumber hasNegotiatedVersion bool - version protocol.VersionNumber + version protocol.Version handshakeChan chan struct{} conn quicConn - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer tracingID uint64 logger utils.Logger } @@ -153,7 +153,7 @@ func dial( if c.config.Tracer != nil { c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) } - if c.tracer != nil { + if c.tracer != nil && c.tracer.StartedConnection != nil { c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) } if err := c.dial(ctx); err != nil { @@ -232,8 +232,8 @@ func (c *client) dial(ctx context.Context) error { select { case <-ctx.Done(): - c.conn.shutdown() - return ctx.Err() + c.conn.destroy(nil) + return context.Cause(ctx) case err := <-errorChan: return err case recreateErr := <-recreateChan: diff --git a/vendor/github.com/quic-go/quic-go/closed_conn.go b/vendor/github.com/quic-go/quic-go/closed_conn.go index 0c988b53e..833385327 100644 --- a/vendor/github.com/quic-go/quic-go/closed_conn.go +++ b/vendor/github.com/quic-go/quic-go/closed_conn.go @@ -4,7 +4,6 @@ import ( "math/bits" "net" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" ) @@ -12,9 +11,8 @@ import ( // When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, // with an exponential backoff. type closedLocalConn struct { - counter uint32 - perspective protocol.Perspective - logger utils.Logger + counter uint32 + logger utils.Logger sendPacket func(net.Addr, packetInfo) } @@ -22,11 +20,10 @@ type closedLocalConn struct { var _ packetHandler = &closedLocalConn{} // newClosedLocalConn creates a new closedLocalConn and runs it. -func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { +func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler { return &closedLocalConn{ - sendPacket: sendPacket, - perspective: pers, - logger: logger, + sendPacket: sendPacket, + logger: logger, } } @@ -41,24 +38,20 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) { c.sendPacket(p.remoteAddr, p.info) } -func (c *closedLocalConn) shutdown() {} -func (c *closedLocalConn) destroy(error) {} -func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective } +func (c *closedLocalConn) destroy(error) {} +func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {} // A closedRemoteConn is a connection that was closed remotely. // For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. // We can just ignore those packets. -type closedRemoteConn struct { - perspective protocol.Perspective -} +type closedRemoteConn struct{} var _ packetHandler = &closedRemoteConn{} -func newClosedRemoteConn(pers protocol.Perspective) packetHandler { - return &closedRemoteConn{perspective: pers} +func newClosedRemoteConn() packetHandler { + return &closedRemoteConn{} } -func (s *closedRemoteConn) handlePacket(receivedPacket) {} -func (s *closedRemoteConn) shutdown() {} -func (s *closedRemoteConn) destroy(error) {} -func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } +func (c *closedRemoteConn) handlePacket(receivedPacket) {} +func (c *closedRemoteConn) destroy(error) {} +func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {} diff --git a/vendor/github.com/quic-go/quic-go/codecov.yml b/vendor/github.com/quic-go/quic-go/codecov.yml index 694351509..59e4b58f6 100644 --- a/vendor/github.com/quic-go/quic-go/codecov.yml +++ b/vendor/github.com/quic-go/quic-go/codecov.yml @@ -1,19 +1,12 @@ coverage: round: nearest ignore: - - streams_map_incoming_bidi.go - - streams_map_incoming_uni.go - - streams_map_outgoing_bidi.go - - streams_map_outgoing_uni.go - http3/gzip_reader.go - interop/ - - internal/ackhandler/packet_linkedlist.go - internal/handshake/cipher_suite.go - - internal/utils/byteinterval_linkedlist.go - - internal/utils/newconnectionid_linkedlist.go - - internal/utils/packetinterval_linkedlist.go - internal/utils/linkedlist/linkedlist.go - - logging/null_tracer.go + - internal/testdata + - testutils/ - fuzzing/ - metrics/ status: diff --git a/vendor/github.com/quic-go/quic-go/config.go b/vendor/github.com/quic-go/quic-go/config.go index 59df4cfd9..ee032e6ea 100644 --- a/vendor/github.com/quic-go/quic-go/config.go +++ b/vendor/github.com/quic-go/quic-go/config.go @@ -2,11 +2,9 @@ package quic import ( "fmt" - "net" "time" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/quicvarint" ) @@ -17,7 +15,11 @@ func (c *Config) Clone() *Config { } func (c *Config) handshakeTimeout() time.Duration { - return utils.Max(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout) + return 2 * c.HandshakeIdleTimeout +} + +func (c *Config) maxRetryTokenAge() time.Duration { + return c.handshakeTimeout() } func validateConfig(config *Config) error { @@ -46,22 +48,6 @@ func validateConfig(config *Config) error { return nil } -// populateServerConfig populates fields in the quic.Config with their default values, if none are set -// it may be called with nil -func populateServerConfig(config *Config) *Config { - config = populateConfig(config) - if config.MaxTokenAge == 0 { - config.MaxTokenAge = protocol.TokenValidity - } - if config.MaxRetryTokenAge == 0 { - config.MaxRetryTokenAge = protocol.RetryTokenValidity - } - if config.RequireAddressValidation == nil { - config.RequireAddressValidation = func(net.Addr) bool { return false } - } - return config -} - // populateConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateConfig(config *Config) *Config { @@ -110,26 +96,22 @@ func populateConfig(config *Config) *Config { } return &Config{ - GetConfigForClient: config.GetConfigForClient, - Versions: versions, - HandshakeIdleTimeout: handshakeIdleTimeout, - MaxIdleTimeout: idleTimeout, - MaxTokenAge: config.MaxTokenAge, - MaxRetryTokenAge: config.MaxRetryTokenAge, - RequireAddressValidation: config.RequireAddressValidation, - KeepAlivePeriod: config.KeepAlivePeriod, - InitialStreamReceiveWindow: initialStreamReceiveWindow, - MaxStreamReceiveWindow: maxStreamReceiveWindow, - InitialConnectionReceiveWindow: initialConnectionReceiveWindow, - MaxConnectionReceiveWindow: maxConnectionReceiveWindow, - AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, - MaxIncomingStreams: maxIncomingStreams, - MaxIncomingUniStreams: maxIncomingUniStreams, - TokenStore: config.TokenStore, - EnableDatagrams: config.EnableDatagrams, - DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, - DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, - Allow0RTT: config.Allow0RTT, - Tracer: config.Tracer, + GetConfigForClient: config.GetConfigForClient, + Versions: versions, + HandshakeIdleTimeout: handshakeIdleTimeout, + MaxIdleTimeout: idleTimeout, + KeepAlivePeriod: config.KeepAlivePeriod, + InitialStreamReceiveWindow: initialStreamReceiveWindow, + MaxStreamReceiveWindow: maxStreamReceiveWindow, + InitialConnectionReceiveWindow: initialConnectionReceiveWindow, + MaxConnectionReceiveWindow: maxConnectionReceiveWindow, + AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, + MaxIncomingStreams: maxIncomingStreams, + MaxIncomingUniStreams: maxIncomingUniStreams, + TokenStore: config.TokenStore, + EnableDatagrams: config.EnableDatagrams, + DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, + Allow0RTT: config.Allow0RTT, + Tracer: config.Tracer, } } diff --git a/vendor/github.com/quic-go/quic-go/conn_id_generator.go b/vendor/github.com/quic-go/quic-go/conn_id_generator.go index 2d28dc619..d7be6540e 100644 --- a/vendor/github.com/quic-go/quic-go/conn_id_generator.go +++ b/vendor/github.com/quic-go/quic-go/conn_id_generator.go @@ -5,7 +5,6 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" - "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" ) @@ -20,7 +19,7 @@ type connIDGenerator struct { getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken removeConnectionID func(protocol.ConnectionID) retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte) + replaceWithClosed func([]protocol.ConnectionID, []byte) queueControlFrame func(wire.Frame) } @@ -31,7 +30,7 @@ func newConnIDGenerator( getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), - replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte), + replaceWithClosed func([]protocol.ConnectionID, []byte), queueControlFrame func(wire.Frame), generator ConnectionIDGenerator, ) *connIDGenerator { @@ -60,7 +59,7 @@ func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { // transport parameter. // We currently don't send the preferred_address transport parameter, // so we can issue (limit - 1) connection IDs. - for i := uint64(len(m.activeSrcConnIDs)); i < utils.Min(limit, protocol.MaxIssuedConnectionIDs); i++ { + for i := uint64(len(m.activeSrcConnIDs)); i < min(limit, protocol.MaxIssuedConnectionIDs); i++ { if err := m.issueNewConnID(); err != nil { return err } @@ -127,7 +126,7 @@ func (m *connIDGenerator) RemoveAll() { } } -func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) { +func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) { connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) if m.initialClientDestConnID != nil { connIDs = append(connIDs, *m.initialClientDestConnID) @@ -135,5 +134,5 @@ func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose for _, connID := range m.activeSrcConnIDs { connIDs = append(connIDs, connID) } - m.replaceWithClosed(connIDs, pers, connClose) + m.replaceWithClosed(connIDs, connClose) } diff --git a/vendor/github.com/quic-go/quic-go/conn_id_manager.go b/vendor/github.com/quic-go/quic-go/conn_id_manager.go index ba65aec04..4aa3f749e 100644 --- a/vendor/github.com/quic-go/quic-go/conn_id_manager.go +++ b/vendor/github.com/quic-go/quic-go/conn_id_manager.go @@ -145,7 +145,7 @@ func (h *connIDManager) updateConnectionID() { h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: h.activeSequenceNumber, }) - h.highestRetired = utils.Max(h.highestRetired, h.activeSequenceNumber) + h.highestRetired = max(h.highestRetired, h.activeSequenceNumber) if h.activeStatelessResetToken != nil { h.removeStatelessResetToken(*h.activeStatelessResetToken) } diff --git a/vendor/github.com/quic-go/quic-go/connection.go b/vendor/github.com/quic-go/quic-go/connection.go index 877f2d033..f8bcd613c 100644 --- a/vendor/github.com/quic-go/quic-go/connection.go +++ b/vendor/github.com/quic-go/quic-go/connection.go @@ -25,7 +25,7 @@ import ( ) type unpacker interface { - UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) + UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) } @@ -93,7 +93,7 @@ type connRunner interface { GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) - ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte) + ReplaceWithClosed([]protocol.ConnectionID, []byte) AddResetToken(protocol.StatelessResetToken, packetHandler) RemoveResetToken(protocol.StatelessResetToken) } @@ -106,7 +106,7 @@ type closeError struct { type errCloseForRecreating struct { nextPacketNumber protocol.PacketNumber - nextVersion protocol.VersionNumber + nextVersion protocol.Version } func (e *errCloseForRecreating) Error() string { @@ -128,7 +128,7 @@ type connection struct { srcConnIDLen int perspective protocol.Perspective - version protocol.VersionNumber + version protocol.Version config *Config conn sendConn @@ -177,6 +177,7 @@ type connection struct { earlyConnReadyChan chan struct{} sentFirstPacket bool + droppedInitialKeys bool handshakeComplete bool handshakeConfirmed bool @@ -208,7 +209,7 @@ type connection struct { connState ConnectionState logID string - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger } @@ -232,10 +233,10 @@ var newConnection = func( tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, clientAddressValidated bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, - v protocol.VersionNumber, + v protocol.Version, ) quicConn { s := &connection{ conn: conn, @@ -243,7 +244,7 @@ var newConnection = func( handshakeDestConnID: destConnID, srcConnIDLen: srcConnID.Len(), tokenGenerator: tokenGenerator, - oneRTTStream: newCryptoStream(true), + oneRTTStream: newCryptoStream(), perspective: protocol.PerspectiveServer, tracer: tracer, logger: logger, @@ -278,6 +279,7 @@ var newConnection = func( getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, clientAddressValidated, + s.conn.capabilities().ECN, s.perspective, s.tracer, s.logger, @@ -306,11 +308,11 @@ var newConnection = func( RetrySourceConnectionID: retrySrcConnID, } if s.config.EnableDatagrams { - params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + params.MaxDatagramFrameSize = wire.MaxDatagramSize } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentTransportParameters != nil { s.tracer.SentTransportParameters(params) } cs := handshake.NewCryptoSetupServer( @@ -344,10 +346,10 @@ var newClientConnection = func( initialPacketNumber protocol.PacketNumber, enable0RTT bool, hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, - v protocol.VersionNumber, + v protocol.Version, ) quicConn { s := &connection{ conn: conn, @@ -385,13 +387,14 @@ var newClientConnection = func( initialPacketNumber, getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, - false, /* has no effect */ + false, // has no effect + s.conn.capabilities().ECN, s.perspective, s.tracer, s.logger, ) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - oneRTTStream := newCryptoStream(true) + oneRTTStream := newCryptoStream() params := &wire.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -412,11 +415,11 @@ var newClientConnection = func( InitialSourceConnectionID: srcConnID, } if s.config.EnableDatagrams { - params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + params.MaxDatagramFrameSize = wire.MaxDatagramSize } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentTransportParameters != nil { s.tracer.SentTransportParameters(params) } cs := handshake.NewCryptoSetupClient( @@ -447,11 +450,11 @@ var newClientConnection = func( } func (s *connection) preSetup() { - s.initialStream = newCryptoStream(false) - s.handshakeStream = newCryptoStream(false) + s.initialStream = newCryptoStream() + s.handshakeStream = newCryptoStream() s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue() - s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) + s.frameParser = *wire.NewFrameParser(s.config.EnableDatagrams) s.rttStats = &utils.RTTStats{} s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ByteCount(s.config.InitialConnectionReceiveWindow), @@ -518,6 +521,9 @@ func (s *connection) run() error { runLoop: for { + if s.framer.QueuedTooManyControlFrames() { + s.closeLocal(&qerr.TransportError{ErrorCode: InternalError}) + } // Close immediately if requested select { case closeErr = <-s.closeChan: @@ -627,7 +633,7 @@ runLoop: sendQueueAvailable = s.sendQueue.Available() continue } - if err := s.triggerSending(); err != nil { + if err := s.triggerSending(now); err != nil { s.closeLocal(err) } if s.sendQueue.WouldBlock() { @@ -640,8 +646,10 @@ runLoop: s.cryptoStreamHandler.Close() s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE s.handleCloseError(&closeErr) - if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { - s.tracer.Close() + if s.tracer != nil && s.tracer.Close != nil { + if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) { + s.tracer.Close() + } } s.logger.Infof("Connection %s closed.", s.logID) s.timer.Stop() @@ -671,12 +679,13 @@ func (s *connection) ConnectionState() ConnectionState { cs := s.cryptoStreamHandler.ConnectionState() s.connState.TLS = cs.ConnectionState s.connState.Used0RTT = cs.Used0RTT + s.connState.GSO = s.conn.capabilities().GSO return s.connState } // Time when the connection should time out func (s *connection) nextIdleTimeoutTime() time.Time { - idleTimeout := utils.Max(s.idleTimeout, s.rttStats.PTO(true)*3) + idleTimeout := max(s.idleTimeout, s.rttStats.PTO(true)*3) return s.idleTimeoutStartTime().Add(idleTimeout) } @@ -686,7 +695,7 @@ func (s *connection) nextKeepAliveTime() time.Time { if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() { return time.Time{} } - keepAliveInterval := utils.Max(s.keepAliveInterval, s.rttStats.PTO(true)*3/2) + keepAliveInterval := max(s.keepAliveInterval, s.rttStats.PTO(true)*3/2) return s.lastPacketReceivedTime.Add(keepAliveInterval) } @@ -726,6 +735,10 @@ func (s *connection) handleHandshakeComplete() error { s.connIDManager.SetHandshakeComplete() s.connIDGenerator.SetHandshakeComplete() + if s.tracer != nil && s.tracer.ChoseALPN != nil { + s.tracer.ChoseALPN(s.cryptoStreamHandler.ConnectionState().NegotiatedProtocol) + } + // The server applies transport parameters right away, but the client side has to wait for handshake completion. // During a 0-RTT connection, the client is only allowed to use the new transport parameters for 1-RTT packets. if s.perspective == protocol.PerspectiveClient { @@ -771,7 +784,7 @@ func (s *connection) handleHandshakeConfirmed() error { if maxPacketSize == 0 { maxPacketSize = protocol.MaxByteCount } - s.mtuDiscoverer.Start(utils.Min(maxPacketSize, protocol.MaxPacketBufferSize)) + s.mtuDiscoverer.Start(min(maxPacketSize, protocol.MaxPacketBufferSize)) } return nil } @@ -798,15 +811,15 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { var err error destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen) if err != nil { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) } s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err) break } if destConnID != lastConnID { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID) break @@ -816,12 +829,12 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { if wire.IsLongHeaderPacket(p.data[0]) { hdr, packetData, rest, err := wire.ParsePacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { dropReason := logging.PacketDropHeaderParseError if err == wire.ErrUnsupportedVersion { dropReason = logging.PacketDropUnsupportedVersion } - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason) + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), dropReason) } s.logger.Debugf("error parsing packet: %s", err) break @@ -829,8 +842,8 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { lastConnID = hdr.DestConnectionID if hdr.Version != s.version { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) } s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) break @@ -888,14 +901,14 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) { s.logger.Debugf("Dropping (potentially) duplicate packet.") - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketType1RTT, pn, p.Size(), logging.PacketDropDuplicate) } return false } var log func([]logging.Frame) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedShortHeaderPacket != nil { log = func(frames []logging.Frame) { s.tracer.ReceivedShortHeaderPacket( &logging.ShortHeader{ @@ -905,6 +918,7 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc KeyPhase: keyPhase, }, p.Size(), + p.ecn, frames, ) } @@ -927,22 +941,22 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) }() if hdr.Type == protocol.PacketTypeRetry { - return s.handleRetryPacket(hdr, p.data) + return s.handleRetryPacket(hdr, p.data, p.rcvTime) } // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeInitial, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID) return false } // drop 0-RTT packets, if we are a client if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketType0RTT, p.Size(), logging.PacketDropKeyUnavailable) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropKeyUnavailable) } return false } @@ -958,10 +972,10 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) packet.hdr.Log(s.logger) } - if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.hdr.PacketNumber, packet.encryptionLevel) { + if pn := packet.hdr.PacketNumber; s.receivedPacketHandler.IsPotentiallyDuplicate(pn, packet.encryptionLevel) { s.logger.Debugf("Dropping (potentially) duplicate packet.") - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), pn, p.Size(), logging.PacketDropDuplicate) } return false } @@ -976,8 +990,8 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) { switch err { case handshake.ErrKeysDropped: - if s.tracer != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropKeyUnavailable) } s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size()) case handshake.ErrKeysNotYetAvailable: @@ -992,16 +1006,16 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P }) case handshake.ErrDecryptionFailed: // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err) default: var headerErr *headerParseError if errors.As(err, &headerErr) { // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err) } else { @@ -1013,25 +1027,25 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P return false } -func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { +func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry.") return false } if s.receivedFirstPacket { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since we already received a packet.") return false } destConnID := s.connIDManager.Get() if hdr.SrcConnectionID == destConnID { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") return false @@ -1045,8 +1059,8 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version) if !bytes.Equal(data[len(data)-16:], tag[:]) { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.") return false @@ -1057,12 +1071,12 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID) } - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedRetry != nil { s.tracer.ReceivedRetry(hdr) } newDestConnID := hdr.SrcConnectionID s.receivedRetry = true - if err := s.sentPacketHandler.ResetForRetry(); err != nil { + if err := s.sentPacketHandler.ResetForRetry(rcvTime); err != nil { s.closeLocal(err) return false } @@ -1078,16 +1092,16 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket) } return } src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data) if err != nil { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing Version Negotiation packet: %s", err) return @@ -1095,8 +1109,8 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { for _, v := range supportedVersions { if v == s.version { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedVersion) } // The Version Negotiation packet contains the version that we offered. // This might be a packet sent by an attacker, or it was corrupted. @@ -1105,7 +1119,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { } s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedVersionNegotiationPacket != nil { s.tracer.ReceivedVersionNegotiationPacket(dest, src, supportedVersions) } newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) @@ -1117,7 +1131,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { s.logger.Infof("No compatible QUIC version found.") return } - if s.tracer != nil { + if s.tracer != nil && s.tracer.NegotiatedVersion != nil { s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions) } @@ -1137,8 +1151,8 @@ func (s *connection) handleUnpackedLongHeaderPacket( ) error { if !s.receivedFirstPacket { s.receivedFirstPacket = true - if !s.versionNegotiated && s.tracer != nil { - var clientVersions, serverVersions []protocol.VersionNumber + if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil { + var clientVersions, serverVersions []protocol.Version switch s.perspective { case protocol.PerspectiveClient: clientVersions = s.config.Versions @@ -1164,7 +1178,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( s.handshakeDestConnID = packet.hdr.SrcConnectionID s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) } - if s.tracer != nil { + if s.tracer != nil && s.tracer.StartedConnection != nil { s.tracer.StartedConnection( s.conn.LocalAddr(), s.conn.RemoteAddr(), @@ -1175,7 +1189,8 @@ func (s *connection) handleUnpackedLongHeaderPacket( } } - if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake { + if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake && + !s.droppedInitialKeys { // On the server side, Initial keys are dropped as soon as the first Handshake packet is received. // See Section 4.9.1 of RFC 9001. if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { @@ -1188,9 +1203,9 @@ func (s *connection) handleUnpackedLongHeaderPacket( s.keepAlivePingSent = false var log func([]logging.Frame) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedLongHeaderPacket != nil { log = func(frames []logging.Frame) { - s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, frames) + s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames) } } isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log) @@ -1336,8 +1351,8 @@ func (s *connection) handlePacket(p receivedPacket) { select { case s.receivedPackets <- p: default: - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropDOSPrevention) } } } @@ -1516,7 +1531,7 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr } func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error { - if f.Length(s.version) > protocol.MaxDatagramFrameSize { + if f.Length(s.version) > wire.MaxDatagramSize { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "DATAGRAM frame too large", @@ -1562,13 +1577,6 @@ func (s *connection) closeRemote(e error) { }) } -// Close the connection. It sends a NO_ERROR application error. -// It waits until the run loop has stopped before returning -func (s *connection) shutdown() { - s.closeLocal(nil) - <-s.ctx.Done() -} - func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error { s.closeLocal(&qerr.ApplicationError{ ErrorCode: code, @@ -1578,6 +1586,11 @@ func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) erro return nil } +func (s *connection) closeWithTransportError(code TransportErrorCode) { + s.closeLocal(&qerr.TransportError{ErrorCode: code}) + <-s.ctx.Done() +} + func (s *connection) handleCloseError(closeErr *closeError) { e := closeErr.err if e == nil { @@ -1616,13 +1629,13 @@ func (s *connection) handleCloseError(closeErr *closeError) { s.datagramQueue.CloseWithError(e) } - if s.tracer != nil && !errors.As(e, &recreateErr) { + if s.tracer != nil && s.tracer.ClosedConnection != nil && !errors.As(e, &recreateErr) { s.tracer.ClosedConnection(e) } // If this is a remote close we're done here if closeErr.remote { - s.connIDGenerator.ReplaceWithClosed(s.perspective, nil) + s.connIDGenerator.ReplaceWithClosed(nil) return } if closeErr.immediate { @@ -1639,11 +1652,11 @@ func (s *connection) handleCloseError(closeErr *closeError) { if err != nil { s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) } - s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket) + s.connIDGenerator.ReplaceWithClosed(connClosePacket) } func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedEncryptionLevel != nil { s.tracer.DroppedEncryptionLevel(encLevel) } s.sentPacketHandler.DropPackets(encLevel) @@ -1651,6 +1664,7 @@ func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) erro //nolint:exhaustive // only Initial and 0-RTT need special treatment switch encLevel { case protocol.EncryptionInitial: + s.droppedInitialKeys = true s.cryptoStreamHandler.DiscardInitialKeys() case protocol.Encryption0RTT: s.streamsMap.ResetFor0RTT() @@ -1678,7 +1692,7 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters } func (s *connection) handleTransportParameters(params *wire.TransportParameters) error { - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedTransportParameters != nil { s.tracer.ReceivedTransportParameters(params) } if err := s.checkTransportParameters(params); err != nil { @@ -1745,7 +1759,7 @@ func (s *connection) applyTransportParameters() { params := s.peerParams // Our local idle timeout will always be > 0. s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) - s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) + s.keepAliveInterval = min(s.config.KeepAlivePeriod, min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) s.streamsMap.UpdateLimits(params) s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.connFlowController.UpdateSendWindow(params.InitialMaxData) @@ -1761,9 +1775,8 @@ func (s *connection) applyTransportParameters() { } } -func (s *connection) triggerSending() error { +func (s *connection) triggerSending(now time.Time) error { s.pacingDeadline = time.Time{} - now := time.Now() sendMode := s.sentPacketHandler.SendMode(now) //nolint:exhaustive // No need to handle pacing limited here. @@ -1795,7 +1808,7 @@ func (s *connection) triggerSending() error { s.scheduleSending() return nil } - return s.triggerSending() + return s.triggerSending(now) case ackhandler.SendPTOHandshake: if err := s.sendProbePacket(protocol.EncryptionHandshake, now); err != nil { return err @@ -1804,7 +1817,7 @@ func (s *connection) triggerSending() error { s.scheduleSending() return nil } - return s.triggerSending() + return s.triggerSending(now) case ackhandler.SendPTOAppData: if err := s.sendProbePacket(protocol.Encryption1RTT, now); err != nil { return err @@ -1813,7 +1826,7 @@ func (s *connection) triggerSending() error { s.scheduleSending() return nil } - return s.triggerSending() + return s.triggerSending(now) default: return fmt.Errorf("BUG: invalid send mode %d", sendMode) } @@ -1830,9 +1843,10 @@ func (s *connection) sendPackets(now time.Time) error { if err != nil { return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) - s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, buf.Len()) + ecn := s.sentPacketHandler.ECNMode(true) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false) + s.registerPackedShortHeaderPacket(p, ecn, now) + s.sendQueue.Send(buf, 0, ecn) // This is kind of a hack. We need to trigger sending again somehow. s.pacingDeadline = deadlineSendImmediately return nil @@ -1852,7 +1866,7 @@ func (s *connection) sendPackets(now time.Time) error { return err } s.sentFirstPacket = true - if err := s.sendPackedCoalescedPacket(packet, now); err != nil { + if err := s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now); err != nil { return err } sendMode := s.sentPacketHandler.SendMode(now) @@ -1873,7 +1887,8 @@ func (s *connection) sendPackets(now time.Time) error { func (s *connection) sendPacketsWithoutGSO(now time.Time) error { for { buf := getPacketBuffer() - if _, err := s.appendPacket(buf, s.mtuDiscoverer.CurrentSize(), now); err != nil { + ecn := s.sentPacketHandler.ECNMode(true) + if _, err := s.appendOneShortHeaderPacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil { if err == errNothingToPack { buf.Release() return nil @@ -1881,7 +1896,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error { return err } - s.sendQueue.Send(buf, buf.Len()) + s.sendQueue.Send(buf, 0, ecn) if s.sendQueue.WouldBlock() { return nil @@ -1906,9 +1921,10 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { buf := getLargePacketBuffer() maxSize := s.mtuDiscoverer.CurrentSize() + ecn := s.sentPacketHandler.ECNMode(true) for { var dontSendMore bool - size, err := s.appendPacket(buf, maxSize, now) + size, err := s.appendOneShortHeaderPacket(buf, maxSize, ecn, now) if err != nil { if err != errNothingToPack { return err @@ -1930,15 +1946,19 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error { } } + // Don't send more packets in this batch if they require a different ECN marking than the previous ones. + nextECN := s.sentPacketHandler.ECNMode(true) + // Append another packet if // 1. The congestion controller and pacer allow sending more // 2. The last packet appended was a full-size packet - // 3. We still have enough space for another full-size packet in the buffer - if !dontSendMore && size == maxSize && buf.Len()+maxSize <= buf.Cap() { + // 3. The next packet will have the same ECN marking + // 4. We still have enough space for another full-size packet in the buffer + if !dontSendMore && size == maxSize && nextECN == ecn && buf.Len()+maxSize <= buf.Cap() { continue } - s.sendQueue.Send(buf, maxSize) + s.sendQueue.Send(buf, uint16(maxSize), ecn) if dontSendMore { return nil @@ -1967,6 +1987,7 @@ func (s *connection) resetPacingDeadline() { func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { if !s.handshakeConfirmed { + ecn := s.sentPacketHandler.ECNMode(false) packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { return err @@ -1974,9 +1995,10 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { if packet == nil { return nil } - return s.sendPackedCoalescedPacket(packet, time.Now()) + return s.sendPackedCoalescedPacket(packet, ecn, now) } + ecn := s.sentPacketHandler.ECNMode(true) p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { if err == errNothingToPack { @@ -1984,9 +2006,9 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { } return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) - s.registerPackedShortHeaderPacket(p, now) - s.sendQueue.Send(buf, buf.Len()) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false) + s.registerPackedShortHeaderPacket(p, ecn, now) + s.sendQueue.Send(buf, 0, ecn) return nil } @@ -2018,24 +2040,24 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) } - return s.sendPackedCoalescedPacket(packet, now) + return s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now) } -// appendPacket appends a new packet to the given packetBuffer. +// appendOneShortHeaderPacket appends a new packet to the given packetBuffer. // If there was nothing to pack, the returned size is 0. -func (s *connection) appendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time) (protocol.ByteCount, error) { +func (s *connection) appendOneShortHeaderPacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) { startLen := buf.Len() p, err := s.packer.AppendPacket(buf, maxSize, s.version) if err != nil { return 0, err } size := buf.Len() - startLen - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, size, false) - s.registerPackedShortHeaderPacket(p, now) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, size, false) + s.registerPackedShortHeaderPacket(p, ecn, now) return size, nil } -func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now time.Time) { +func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, now time.Time) { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) { s.firstAckElicitingPacketAfterIdleSentTime = now } @@ -2044,12 +2066,12 @@ func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now ti if p.Ack != nil { largestAcked = p.Ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) + s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket) s.connIDManager.SentPacket() } -func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) error { - s.logCoalescedPacket(packet) +func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN, now time.Time) error { + s.logCoalescedPacket(packet, ecn) for _, p := range packet.longHdrPackets { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { s.firstAckElicitingPacketAfterIdleSentTime = now @@ -2058,8 +2080,9 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time if p.ack != nil { largestAcked = p.ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false) - if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake { + s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false) + if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake && + !s.droppedInitialKeys { // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. // See Section 4.9.1 of RFC 9001. if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { @@ -2075,10 +2098,10 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time if p.Ack != nil { largestAcked = p.Ack.LargestAcked() } - s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) + s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket) } s.connIDManager.SentPacket() - s.sendQueue.Send(packet.buffer, packet.buffer.Len()) + s.sendQueue.Send(packet.buffer, 0, ecn) return nil } @@ -2100,11 +2123,12 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { if err != nil { return nil, err } - s.logCoalescedPacket(packet) - return packet.buffer.Data, s.conn.Write(packet.buffer.Data, packet.buffer.Len()) + ecn := s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()) + s.logCoalescedPacket(packet, ecn) + return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn) } -func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { +func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) { // quic-go logging if s.logger.Debug() { p.header.Log(s.logger) @@ -2120,7 +2144,7 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { } // tracing - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentLongHeaderPacket != nil { frames := make([]logging.Frame, 0, len(p.frames)) for _, f := range p.frames { frames = append(frames, logutils.ConvertFrame(f.Frame)) @@ -2132,7 +2156,7 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { if p.ack != nil { ack = logutils.ConvertAckFrame(p.ack) } - s.tracer.SentLongHeaderPacket(p.header, p.length, ack, frames) + s.tracer.SentLongHeaderPacket(p.header, p.length, ecn, ack, frames) } } @@ -2144,11 +2168,12 @@ func (s *connection) logShortHeaderPacket( pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit, + ecn protocol.ECN, size protocol.ByteCount, isCoalesced bool, ) { if s.logger.Debug() && !isCoalesced { - s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT", pn, size, s.logID) + s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn) } // quic-go logging if s.logger.Debug() { @@ -2165,7 +2190,7 @@ func (s *connection) logShortHeaderPacket( } // tracing - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentShortHeaderPacket != nil { fs := make([]logging.Frame, 0, len(frames)+len(streamFrames)) for _, f := range frames { fs = append(fs, logutils.ConvertFrame(f.Frame)) @@ -2185,13 +2210,14 @@ func (s *connection) logShortHeaderPacket( KeyPhase: kp, }, size, + ecn, ack, fs, ) } } -func (s *connection) logCoalescedPacket(packet *coalescedPacket) { +func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) { if s.logger.Debug() { // There's a short period between dropping both Initial and Handshake keys and completion of the handshake, // during which we might call PackCoalescedPacket but just pack a short header packet. @@ -2204,6 +2230,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { packet.shortHdrPacket.PacketNumber, packet.shortHdrPacket.PacketNumberLen, packet.shortHdrPacket.KeyPhase, + ecn, packet.shortHdrPacket.Length, false, ) @@ -2216,10 +2243,10 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { } } for _, p := range packet.longHdrPackets { - s.logLongHeaderPacket(p) + s.logLongHeaderPacket(p, ecn) } if p := packet.shortHdrPacket; p != nil { - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true) } } @@ -2285,14 +2312,14 @@ func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging panic("shouldn't queue undecryptable packets after handshake completion") } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { - if s.tracer != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropDOSPrevention) + if s.tracer != nil && s.tracer.DroppedPacket != nil { + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropDOSPrevention) } s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size()) return } s.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.BufferedPacket != nil { s.tracer.BufferedPacket(pt, p.Size()) } s.undecryptablePackets = append(s.undecryptablePackets, p) @@ -2324,21 +2351,23 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) { } } -func (s *connection) SendMessage(p []byte) error { +func (s *connection) SendDatagram(p []byte) error { if !s.supportsDatagrams() { return errors.New("datagram support disabled") } f := &wire.DatagramFrame{DataLenPresent: true} if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) { - return errors.New("message too large") + return &DatagramTooLargeError{ + PeerMaxDatagramFrameSize: int64(s.peerParams.MaxDatagramFrameSize), + } } f.Data = make([]byte, len(p)) copy(f.Data, p) - return s.datagramQueue.AddAndWait(f) + return s.datagramQueue.Add(f) } -func (s *connection) ReceiveMessage(ctx context.Context) ([]byte, error) { +func (s *connection) ReceiveDatagram(ctx context.Context) ([]byte, error) { if !s.config.EnableDatagrams { return nil, errors.New("datagram support disabled") } @@ -2353,11 +2382,7 @@ func (s *connection) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } -func (s *connection) getPerspective() protocol.Perspective { - return s.perspective -} - -func (s *connection) GetVersion() protocol.VersionNumber { +func (s *connection) GetVersion() protocol.Version { return s.version } diff --git a/vendor/github.com/quic-go/quic-go/crypto_stream.go b/vendor/github.com/quic-go/quic-go/crypto_stream.go index 5ce2125de..abc7ddcf8 100644 --- a/vendor/github.com/quic-go/quic-go/crypto_stream.go +++ b/vendor/github.com/quic-go/quic-go/crypto_stream.go @@ -6,7 +6,6 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" - "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" ) @@ -30,17 +29,10 @@ type cryptoStreamImpl struct { writeOffset protocol.ByteCount writeBuf []byte - - // Reassemble TLS handshake messages before returning them from GetCryptoData. - // This is only needed because crypto/tls doesn't correctly handle post-handshake messages. - onlyCompleteMsg bool } -func newCryptoStream(onlyCompleteMsg bool) cryptoStream { - return &cryptoStreamImpl{ - queue: newFrameSorter(), - onlyCompleteMsg: onlyCompleteMsg, - } +func newCryptoStream() cryptoStream { + return &cryptoStreamImpl{queue: newFrameSorter()} } func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { @@ -63,7 +55,7 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { // could e.g. be a retransmission return nil } - s.highestOffset = utils.Max(s.highestOffset, highestOffset) + s.highestOffset = max(s.highestOffset, highestOffset) if err := s.queue.Push(f.Data, f.Offset, nil); err != nil { return err } @@ -78,20 +70,6 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { // GetCryptoData retrieves data that was received in CRYPTO frames func (s *cryptoStreamImpl) GetCryptoData() []byte { - if s.onlyCompleteMsg { - if len(s.msgBuf) < 4 { - return nil - } - msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) - if len(s.msgBuf) < msgLen { - return nil - } - msg := make([]byte, msgLen) - copy(msg, s.msgBuf[:msgLen]) - s.msgBuf = s.msgBuf[msgLen:] - return msg - } - b := s.msgBuf s.msgBuf = nil return b @@ -120,7 +98,7 @@ func (s *cryptoStreamImpl) HasData() bool { func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame { f := &wire.CryptoFrame{Offset: s.writeOffset} - n := utils.Min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) + n := min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) f.Data = s.writeBuf[:n] s.writeBuf = s.writeBuf[n:] s.writeOffset += n diff --git a/vendor/github.com/quic-go/quic-go/datagram_queue.go b/vendor/github.com/quic-go/quic-go/datagram_queue.go index ca80d404c..e26285b26 100644 --- a/vendor/github.com/quic-go/quic-go/datagram_queue.go +++ b/vendor/github.com/quic-go/quic-go/datagram_queue.go @@ -4,14 +4,20 @@ import ( "context" "sync" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/utils/ringbuffer" "github.com/quic-go/quic-go/internal/wire" ) +const ( + maxDatagramSendQueueLen = 32 + maxDatagramRcvQueueLen = 128 +) + type datagramQueue struct { - sendQueue chan *wire.DatagramFrame - nextFrame *wire.DatagramFrame + sendMx sync.Mutex + sendQueue ringbuffer.RingBuffer[*wire.DatagramFrame] + sent chan struct{} // used to notify Add that a datagram was dequeued rcvMx sync.Mutex rcvQueue [][]byte @@ -22,60 +28,65 @@ type datagramQueue struct { hasData func() - dequeued chan struct{} - logger utils.Logger } func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { return &datagramQueue{ - hasData: hasData, - sendQueue: make(chan *wire.DatagramFrame, 1), - rcvd: make(chan struct{}, 1), - dequeued: make(chan struct{}), - closed: make(chan struct{}), - logger: logger, + hasData: hasData, + rcvd: make(chan struct{}, 1), + sent: make(chan struct{}, 1), + closed: make(chan struct{}), + logger: logger, } } -// AddAndWait queues a new DATAGRAM frame for sending. -// It blocks until the frame has been dequeued. -func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { - select { - case h.sendQueue <- f: - h.hasData() - case <-h.closed: - return h.closeErr - } +// Add queues a new DATAGRAM frame for sending. +// Up to 32 DATAGRAM frames will be queued. +// Once that limit is reached, Add blocks until the queue size has reduced. +func (h *datagramQueue) Add(f *wire.DatagramFrame) error { + h.sendMx.Lock() - select { - case <-h.dequeued: - return nil - case <-h.closed: - return h.closeErr + for { + if h.sendQueue.Len() < maxDatagramSendQueueLen { + h.sendQueue.PushBack(f) + h.sendMx.Unlock() + h.hasData() + return nil + } + select { + case <-h.sent: // drain the queue so we don't loop immediately + default: + } + h.sendMx.Unlock() + select { + case <-h.closed: + return h.closeErr + case <-h.sent: + } + h.sendMx.Lock() } } // Peek gets the next DATAGRAM frame for sending. // If actually sent out, Pop needs to be called before the next call to Peek. func (h *datagramQueue) Peek() *wire.DatagramFrame { - if h.nextFrame != nil { - return h.nextFrame - } - select { - case h.nextFrame = <-h.sendQueue: - h.dequeued <- struct{}{} - default: + h.sendMx.Lock() + defer h.sendMx.Unlock() + if h.sendQueue.Empty() { return nil } - return h.nextFrame + return h.sendQueue.PeekFront() } func (h *datagramQueue) Pop() { - if h.nextFrame == nil { - panic("datagramQueue BUG: Pop called for nil frame") + h.sendMx.Lock() + defer h.sendMx.Unlock() + _ = h.sendQueue.PopFront() + select { + case h.sent <- struct{}{}: + default: } - h.nextFrame = nil } // HandleDatagramFrame handles a received DATAGRAM frame. @@ -84,7 +95,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { copy(data, f.Data) var queued bool h.rcvMx.Lock() - if len(h.rcvQueue) < protocol.DatagramRcvQueueLen { + if len(h.rcvQueue) < maxDatagramRcvQueueLen { h.rcvQueue = append(h.rcvQueue, data) queued = true select { @@ -94,7 +105,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { } h.rcvMx.Unlock() if !queued && h.logger.Debug() { - h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data)) + h.logger.Debugf("Discarding received DATAGRAM frame (%d bytes payload)", len(f.Data)) } } diff --git a/vendor/github.com/quic-go/quic-go/errors.go b/vendor/github.com/quic-go/quic-go/errors.go index c9fb0a07b..fda3c9247 100644 --- a/vendor/github.com/quic-go/quic-go/errors.go +++ b/vendor/github.com/quic-go/quic-go/errors.go @@ -61,3 +61,15 @@ func (e *StreamError) Error() string { } return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode) } + +// DatagramTooLargeError is returned from Connection.SendDatagram if the payload is too large to be sent. +type DatagramTooLargeError struct { + PeerMaxDatagramFrameSize int64 +} + +func (e *DatagramTooLargeError) Is(target error) bool { + _, ok := target.(*DatagramTooLargeError) + return ok +} + +func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" } diff --git a/vendor/github.com/quic-go/quic-go/framer.go b/vendor/github.com/quic-go/quic-go/framer.go index 9409af4c2..1e6219a46 100644 --- a/vendor/github.com/quic-go/quic-go/framer.go +++ b/vendor/github.com/quic-go/quic-go/framer.go @@ -15,14 +15,26 @@ type framer interface { HasData() bool QueueControlFrame(wire.Frame) - AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) AddActiveStream(protocol.StreamID) - AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) + AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) Handle0RTTRejection() error + + // QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length. + // This is a hack. + // It is easier to implement than propagating an error return value in QueueControlFrame. + // The correct solution would be to queue frames with their respective structs. + // See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames. + QueuedTooManyControlFrames() bool } +const ( + maxPathResponses = 256 + maxControlFrames = 16 << 10 +) + type framerI struct { mutex sync.Mutex @@ -31,8 +43,10 @@ type framerI struct { activeStreams map[protocol.StreamID]struct{} streamQueue ringbuffer.RingBuffer[protocol.StreamID] - controlFrameMutex sync.Mutex - controlFrames []wire.Frame + controlFrameMutex sync.Mutex + controlFrames []wire.Frame + pathResponses []*wire.PathResponseFrame + queuedTooManyControlFrames bool } var _ framer = &framerI{} @@ -52,20 +66,48 @@ func (f *framerI) HasData() bool { return true } f.controlFrameMutex.Lock() - hasData = len(f.controlFrames) > 0 - f.controlFrameMutex.Unlock() - return hasData + defer f.controlFrameMutex.Unlock() + return len(f.controlFrames) > 0 || len(f.pathResponses) > 0 } func (f *framerI) QueueControlFrame(frame wire.Frame) { f.controlFrameMutex.Lock() + defer f.controlFrameMutex.Unlock() + + if pr, ok := frame.(*wire.PathResponseFrame); ok { + // Only queue up to maxPathResponses PATH_RESPONSE frames. + // This limit should be high enough to never be hit in practice, + // unless the peer is doing something malicious. + if len(f.pathResponses) >= maxPathResponses { + return + } + f.pathResponses = append(f.pathResponses, pr) + return + } + // This is a hack. + if len(f.controlFrames) >= maxControlFrames { + f.queuedTooManyControlFrames = true + return + } f.controlFrames = append(f.controlFrames, frame) - f.controlFrameMutex.Unlock() } -func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { - var length protocol.ByteCount +func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) { f.controlFrameMutex.Lock() + defer f.controlFrameMutex.Unlock() + + var length protocol.ByteCount + // add a PATH_RESPONSE first, but only pack a single PATH_RESPONSE per packet + if len(f.pathResponses) > 0 { + frame := f.pathResponses[0] + frameLen := frame.Length(v) + if frameLen <= maxLen { + frames = append(frames, ackhandler.Frame{Frame: frame}) + length += frameLen + f.pathResponses = f.pathResponses[1:] + } + } + for len(f.controlFrames) > 0 { frame := f.controlFrames[len(f.controlFrames)-1] frameLen := frame.Length(v) @@ -76,10 +118,13 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol length += frameLen f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] } - f.controlFrameMutex.Unlock() return frames, length } +func (f *framerI) QueuedTooManyControlFrames() bool { + return f.queuedTooManyControlFrames +} + func (f *framerI) AddActiveStream(id protocol.StreamID) { f.mutex.Lock() if _, ok := f.activeStreams[id]; !ok { @@ -89,7 +134,7 @@ func (f *framerI) AddActiveStream(id protocol.StreamID) { f.mutex.Unlock() } -func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { +func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) { startLen := len(frames) var length protocol.ByteCount f.mutex.Lock() diff --git a/vendor/github.com/quic-go/quic-go/interface.go b/vendor/github.com/quic-go/quic-go/interface.go index c3fb2b105..ca8544d8f 100644 --- a/vendor/github.com/quic-go/quic-go/interface.go +++ b/vendor/github.com/quic-go/quic-go/interface.go @@ -16,8 +16,12 @@ import ( // The StreamID is the ID of a QUIC stream. type StreamID = protocol.StreamID +// A Version is a QUIC version number. +type Version = protocol.Version + // A VersionNumber is a QUIC version number. -type VersionNumber = protocol.VersionNumber +// Deprecated: VersionNumber was renamed to Version. +type VersionNumber = Version const ( // Version1 is RFC 9000 @@ -159,6 +163,9 @@ type Connection interface { OpenStream() (Stream, error) // OpenStreamSync opens a new bidirectional QUIC stream. // It blocks until a new stream can be opened. + // There is no signaling to the peer about new streams: + // The peer can only accept the stream after data has been sent on the stream, + // or the stream has been reset or closed. // If the error is non-nil, it satisfies the net.Error interface. // If the connection was closed due to a timeout, Timeout() will be true. OpenStreamSync(context.Context) (Stream, error) @@ -187,10 +194,14 @@ type Connection interface { // Warning: This API should not be considered stable and might change soon. ConnectionState() ConnectionState - // SendMessage sends a message as a datagram, as specified in RFC 9221. - SendMessage([]byte) error - // ReceiveMessage gets a message received in a datagram, as specified in RFC 9221. - ReceiveMessage(context.Context) ([]byte, error) + // SendDatagram sends a message using a QUIC datagram, as specified in RFC 9221. + // There is no delivery guarantee for DATAGRAM frames, they are not retransmitted if lost. + // The payload of the datagram needs to fit into a single QUIC packet. + // In addition, a datagram may be dropped before being sent out if the available packet size suddenly decreases. + // If the payload is too large to be sent at the current time, a DatagramTooLargeError is returned. + SendDatagram(payload []byte) error + // ReceiveDatagram gets a message received in a datagram, as specified in RFC 9221. + ReceiveDatagram(context.Context) ([]byte, error) } // An EarlyConnection is a connection that is handshaking. @@ -212,6 +223,9 @@ type EarlyConnection interface { // StatelessResetKey is a key used to derive stateless reset tokens. type StatelessResetKey [32]byte +// TokenGeneratorKey is a key used to encrypt session resumption tokens. +type TokenGeneratorKey = handshake.TokenProtectorKey + // A ConnectionID is a QUIC Connection ID, as defined in RFC 9000. // It is not able to handle QUIC Connection IDs longer than 20 bytes, // as they are allowed by RFC 8999. @@ -248,9 +262,10 @@ type Config struct { GetConfigForClient func(info *ClientHelloInfo) (*Config, error) // The QUIC versions that can be negotiated. // If not set, it uses all versions available. - Versions []VersionNumber + Versions []Version // HandshakeIdleTimeout is the idle timeout before completion of the handshake. - // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. + // If we don't receive any packet from the peer within this time, the connection attempt is aborted. + // Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted. // If this value is zero, the timeout is set to 5 seconds. HandshakeIdleTimeout time.Duration // MaxIdleTimeout is the maximum duration that may pass without any incoming network activity. @@ -259,18 +274,6 @@ type Config struct { // If the timeout is exceeded, the connection is closed. // If this value is zero, the timeout is set to 30 seconds. MaxIdleTimeout time.Duration - // RequireAddressValidation determines if a QUIC Retry packet is sent. - // This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT. - // See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details. - // If not set, every client is forced to prove its remote address. - RequireAddressValidation func(net.Addr) bool - // MaxRetryTokenAge is the maximum age of a Retry token. - // If not set, it defaults to 5 seconds. Only valid for a server. - MaxRetryTokenAge time.Duration - // MaxTokenAge is the maximum age of the token presented during the handshake, - // for tokens that were issued on a previous connection. - // If not set, it defaults to 24 hours. Only valid for a server. - MaxTokenAge time.Duration // The TokenStore stores tokens received from the server. // Tokens are used to skip address validation on future connection attempts. // The key used to store tokens is the ServerName from the tls.Config, if set @@ -322,20 +325,23 @@ type Config struct { // Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit. // If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. DisablePathMTUDiscovery bool - // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. - // This can be useful if version information is exchanged out-of-band. - // It has no effect for a client. - DisableVersionNegotiationPackets bool // Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted. // Only valid for the server. Allow0RTT bool // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool - Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer + Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer } +// ClientHelloInfo contains information about an incoming connection attempt. type ClientHelloInfo struct { + // RemoteAddr is the remote address on the Initial packet. + // Unless AddrVerified is set, the address is not yet verified, and could be a spoofed IP address. RemoteAddr net.Addr + // AddrVerified says if the remote address was verified using QUIC's Retry mechanism. + // Note that the Retry mechanism costs one network roundtrip, + // and is not performed unless Transport.MaxUnvalidatedHandshakes is surpassed. + AddrVerified bool } // ConnectionState records basic details about a QUIC connection @@ -345,10 +351,12 @@ type ConnectionState struct { // SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated. // This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams). // If datagram support was negotiated, datagrams can be sent and received using the - // SendMessage and ReceiveMessage methods on the Connection. + // SendDatagram and ReceiveDatagram methods on the Connection. SupportsDatagrams bool // Used0RTT says if 0-RTT resumption was used. Used0RTT bool // Version is the QUIC version of the QUIC connection. - Version VersionNumber + Version Version + // GSO says if generic segmentation offload is used + GSO bool } diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go index 2c7cc4fcf..6f890b4d3 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go @@ -14,10 +14,11 @@ func NewAckHandler( initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, clientAddressValidated bool, + enableECN bool, pers protocol.Perspective, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, ) (SentPacketHandler, ReceivedPacketHandler) { - sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger) - return sph, newReceivedPacketHandler(sph, rttStats, logger) + sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger) + return sph, newReceivedPacketHandler(sph, logger) } diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/ecn.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/ecn.go new file mode 100644 index 000000000..68415ac6c --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/ecn.go @@ -0,0 +1,296 @@ +package ackhandler + +import ( + "fmt" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" +) + +type ecnState uint8 + +const ( + ecnStateInitial ecnState = iota + ecnStateTesting + ecnStateUnknown + ecnStateCapable + ecnStateFailed +) + +// must fit into an uint8, otherwise numSentTesting and numLostTesting must have a larger type +const numECNTestingPackets = 10 + +type ecnHandler interface { + SentPacket(protocol.PacketNumber, protocol.ECN) + Mode() protocol.ECN + HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) + LostPacket(protocol.PacketNumber) +} + +// The ecnTracker performs ECN validation of a path. +// Once failed, it doesn't do any re-validation of the path. +// It is designed only work for 1-RTT packets, it doesn't handle multiple packet number spaces. +// In order to avoid revealing any internal state to on-path observers, +// callers should make sure to start using ECN (i.e. calling Mode) for the very first 1-RTT packet sent. +// The validation logic implemented here strictly follows the algorithm described in RFC 9000 section 13.4.2 and A.4. +type ecnTracker struct { + state ecnState + numSentTesting, numLostTesting uint8 + + firstTestingPacket protocol.PacketNumber + lastTestingPacket protocol.PacketNumber + firstCapablePacket protocol.PacketNumber + + numSentECT0, numSentECT1 int64 + numAckedECT0, numAckedECT1, numAckedECNCE int64 + + tracer *logging.ConnectionTracer + logger utils.Logger +} + +var _ ecnHandler = &ecnTracker{} + +func newECNTracker(logger utils.Logger, tracer *logging.ConnectionTracer) *ecnTracker { + return &ecnTracker{ + firstTestingPacket: protocol.InvalidPacketNumber, + lastTestingPacket: protocol.InvalidPacketNumber, + firstCapablePacket: protocol.InvalidPacketNumber, + state: ecnStateInitial, + logger: logger, + tracer: tracer, + } +} + +func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) { + //nolint:exhaustive // These are the only ones we need to take care of. + switch ecn { + case protocol.ECNNon: + return + case protocol.ECT0: + e.numSentECT0++ + case protocol.ECT1: + e.numSentECT1++ + case protocol.ECNUnsupported: + if e.state != ecnStateFailed { + panic("didn't expect ECN to be unsupported") + } + default: + panic(fmt.Sprintf("sent packet with unexpected ECN marking: %s", ecn)) + } + + if e.state == ecnStateCapable && e.firstCapablePacket == protocol.InvalidPacketNumber { + e.firstCapablePacket = pn + } + + if e.state != ecnStateTesting { + return + } + + e.numSentTesting++ + if e.firstTestingPacket == protocol.InvalidPacketNumber { + e.firstTestingPacket = pn + } + if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + } + e.state = ecnStateUnknown + e.lastTestingPacket = pn + } +} + +func (e *ecnTracker) Mode() protocol.ECN { + switch e.state { + case ecnStateInitial: + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) + } + e.state = ecnStateTesting + return e.Mode() + case ecnStateTesting, ecnStateCapable: + return protocol.ECT0 + case ecnStateUnknown, ecnStateFailed: + return protocol.ECNNon + default: + panic(fmt.Sprintf("unknown ECN state: %d", e.state)) + } +} + +func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) { + if e.state != ecnStateTesting && e.state != ecnStateUnknown { + return + } + if !e.isTestingPacket(pn) { + return + } + e.numLostTesting++ + // Only proceed if we have sent all 10 testing packets. + if e.state != ecnStateUnknown { + return + } + if e.numLostTesting >= e.numSentTesting { + e.logger.Debugf("Disabling ECN. All testing packets were lost.") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets) + } + e.state = ecnStateFailed + return + } + // Path validation also fails if some testing packets are lost, and all other testing packets where CE-marked + e.failIfMangled() +} + +// HandleNewlyAcked handles the ECN counts on an ACK frame. +// It must only be called for ACK frames that increase the largest acknowledged packet number, +// see section 13.4.2.1 of RFC 9000. +func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) { + if e.state == ecnStateFailed { + return false + } + + // ECN validation can fail if the received total count for either ECT(0) or ECT(1) exceeds + // the total number of packets sent with each corresponding ECT codepoint. + if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 { + e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent) + } + e.state = ecnStateFailed + return false + } + + // Count ECT0 and ECT1 marks that we used when sending the packets that are now being acknowledged. + var ackedECT0, ackedECT1 int64 + for _, p := range packets { + //nolint:exhaustive // We only ever send ECT(0) and ECT(1). + switch e.ecnMarking(p.PacketNumber) { + case protocol.ECT0: + ackedECT0++ + case protocol.ECT1: + ackedECT1++ + } + } + + // If an ACK frame newly acknowledges a packet that the endpoint sent with either the ECT(0) or ECT(1) + // codepoint set, ECN validation fails if the corresponding ECN counts are not present in the ACK frame. + // This check detects: + // * paths that bleach all ECN marks, and + // * peers that don't report any ECN counts + if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 { + e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts) + } + e.state = ecnStateFailed + return false + } + + // Determine the increase in ECT0, ECT1 and ECNCE marks + newECT0 := ect0 - e.numAckedECT0 + newECT1 := ect1 - e.numAckedECT1 + newECNCE := ecnce - e.numAckedECNCE + + // We're only processing ACKs that increase the Largest Acked. + // Therefore, the ECN counters should only ever increase. + // Any decrease means that the peer's counting logic is broken. + if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 { + e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts) + } + e.state = ecnStateFailed + return false + } + + // ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is less than the number + // of newly acknowledged packets that were originally sent with an ECT(0) marking. + // This could be the result of (partial) bleaching. + if newECT0+newECNCE < ackedECT0 { + e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) + } + e.state = ecnStateFailed + return false + } + // Similarly, ECN validation fails if the sum of the increases to ECT(1) and ECN-CE counts is less than + // the number of newly acknowledged packets sent with an ECT(1) marking. + if newECT1+newECNCE < ackedECT1 { + e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) + } + e.state = ecnStateFailed + return false + } + + // update our counters + e.numAckedECT0 = ect0 + e.numAckedECT1 = ect1 + e.numAckedECNCE = ecnce + + // Detect mangling (a path remarking all ECN-marked testing packets as CE), + // once all 10 testing packets have been sent out. + if e.state == ecnStateUnknown { + e.failIfMangled() + if e.state == ecnStateFailed { + return false + } + } + if e.state == ecnStateTesting || e.state == ecnStateUnknown { + var ackedTestingPacket bool + for _, p := range packets { + if e.isTestingPacket(p.PacketNumber) { + ackedTestingPacket = true + break + } + } + // This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE). + if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) { + e.logger.Debugf("ECN capability confirmed.") + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) + } + e.state = ecnStateCapable + } + } + + // Don't trust CE marks before having confirmed ECN capability of the path. + // Otherwise, mangling would be misinterpreted as actual congestion. + return e.state == ecnStateCapable && newECNCE > 0 +} + +// failIfMangled fails ECN validation if all testing packets are lost or CE-marked. +func (e *ecnTracker) failIfMangled() { + numAckedECNCE := e.numAckedECNCE + int64(e.numLostTesting) + if e.numSentECT0+e.numSentECT1 > numAckedECNCE { + return + } + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { + e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) + } + e.state = ecnStateFailed +} + +func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN { + if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber { + return protocol.ECNNon + } + if pn < e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber { + return protocol.ECT0 + } + if pn < e.firstCapablePacket || e.firstCapablePacket == protocol.InvalidPacketNumber { + return protocol.ECNNon + } + // We don't need to deal with the case when ECN validation fails, + // since we're ignoring any ECN counts reported in ACK frames in that case. + return protocol.ECT0 +} + +func (e *ecnTracker) isTestingPacket(pn protocol.PacketNumber) bool { + if e.firstTestingPacket == protocol.InvalidPacketNumber { + return false + } + return pn >= e.firstTestingPacket && (pn <= e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber) +} diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go index ba95f8a9b..ba8cbbdae 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go @@ -10,13 +10,13 @@ import ( // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet - SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, isPathMTUProbePacket bool) + SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool) // ReceivedAck processes an ACK frame. // It does not store a copy of the frame. - ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) + ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error) ReceivedBytes(protocol.ByteCount) DropPackets(protocol.EncryptionLevel) - ResetForRetry() error + ResetForRetry(rcvTime time.Time) error SetHandshakeConfirmed() // The SendMode determines if and what kind of packets can be sent. @@ -29,6 +29,7 @@ type SentPacketHandler interface { // only to be called once the handshake is complete QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ + ECNMode(isShortHeaderPacket bool) protocol.ECN // isShortHeaderPacket should only be true for non-coalesced 1-RTT packets PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber @@ -44,7 +45,7 @@ type sentPacketTracker interface { // ReceivedPacketHandler handles ACKs needed to send for incoming packets type ReceivedPacketHandler interface { IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool - ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error + ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, ackEliciting bool) error DropPackets(protocol.EncryptionLevel) GetAlarmTimeout() time.Time diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go index d61783671..0031e6b1c 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go @@ -2,5 +2,8 @@ package ackhandler -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker" type SentPacketTracker = sentPacketTracker + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_ecn_handler_test.go github.com/quic-go/quic-go/internal/ackhandler ECNHandler" +type ECNHandler = ecnHandler diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go index e84171e3c..4a9db8635 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go @@ -80,5 +80,5 @@ func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) { func (p *skippingPacketNumberGenerator) generateNewSkip() { // make sure that there are never two consecutive packet numbers that are skipped p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) - p.period = utils.Min(2*p.period, p.maxPeriod) + p.period = min(2*p.period, p.maxPeriod) } diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go index e11ee1eeb..1175c790b 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go @@ -14,23 +14,19 @@ type receivedPacketHandler struct { initialPackets *receivedPacketTracker handshakePackets *receivedPacketTracker - appDataPackets *receivedPacketTracker + appDataPackets appDataReceivedPacketTracker lowest1RTTPacket protocol.PacketNumber } var _ ReceivedPacketHandler = &receivedPacketHandler{} -func newReceivedPacketHandler( - sentPackets sentPacketTracker, - rttStats *utils.RTTStats, - logger utils.Logger, -) ReceivedPacketHandler { +func newReceivedPacketHandler(sentPackets sentPacketTracker, logger utils.Logger) ReceivedPacketHandler { return &receivedPacketHandler{ sentPackets: sentPackets, - initialPackets: newReceivedPacketTracker(rttStats, logger), - handshakePackets: newReceivedPacketTracker(rttStats, logger), - appDataPackets: newReceivedPacketTracker(rttStats, logger), + initialPackets: newReceivedPacketTracker(), + handshakePackets: newReceivedPacketTracker(), + appDataPackets: *newAppDataReceivedPacketTracker(logger), lowest1RTTPacket: protocol.InvalidPacketNumber, } } @@ -40,29 +36,29 @@ func (h *receivedPacketHandler) ReceivedPacket( ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, - shouldInstigateAck bool, + ackEliciting bool, ) error { h.sentPackets.ReceivedPacket(encLevel) switch encLevel { case protocol.EncryptionInitial: - return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) case protocol.EncryptionHandshake: // The Handshake packet number space might already have been dropped as a result // of processing the CRYPTO frame that was contained in this packet. if h.handshakePackets == nil { return nil } - return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) case protocol.Encryption0RTT: if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket) } - return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting) case protocol.Encryption1RTT: if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { h.lowest1RTTPacket = pn } - if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck); err != nil { + if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil { return err } h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked()) @@ -88,41 +84,28 @@ func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { } func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { - var initialAlarm, handshakeAlarm time.Time - if h.initialPackets != nil { - initialAlarm = h.initialPackets.GetAlarmTimeout() - } - if h.handshakePackets != nil { - handshakeAlarm = h.handshakePackets.GetAlarmTimeout() - } - oneRTTAlarm := h.appDataPackets.GetAlarmTimeout() - return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm) + return h.appDataPackets.GetAlarmTimeout() } func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame { - var ack *wire.AckFrame //nolint:exhaustive // 0-RTT packets can't contain ACK frames. switch encLevel { case protocol.EncryptionInitial: if h.initialPackets != nil { - ack = h.initialPackets.GetAckFrame(onlyIfQueued) + return h.initialPackets.GetAckFrame() } + return nil case protocol.EncryptionHandshake: if h.handshakePackets != nil { - ack = h.handshakePackets.GetAckFrame(onlyIfQueued) + return h.handshakePackets.GetAckFrame() } + return nil case protocol.Encryption1RTT: - // 0-RTT packets can't contain ACK frames return h.appDataPackets.GetAckFrame(onlyIfQueued) default: + // 0-RTT packets can't contain ACK frames return nil } - // For Initial and Handshake ACKs, the delay time is ignored by the receiver. - // Set it to 0 in order to save bytes. - if ack != nil { - ack.DelayTime = 0 - } - return ack } func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool { diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go index b18838663..08af6f1ee 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go @@ -9,93 +9,147 @@ import ( "github.com/quic-go/quic-go/internal/wire" ) -// number of ack-eliciting packets received before sending an ack. +// The receivedPacketTracker tracks packets for the Initial and Handshake packet number space. +// Every received packet is acknowledged immediately. +type receivedPacketTracker struct { + ect0, ect1, ecnce uint64 + + packetHistory receivedPacketHistory + + lastAck *wire.AckFrame + hasNewAck bool // true as soon as we received an ack-eliciting new packet +} + +func newReceivedPacketTracker() *receivedPacketTracker { + return &receivedPacketTracker{packetHistory: *newReceivedPacketHistory()} +} + +func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error { + if isNew := h.packetHistory.ReceivedPacket(pn); !isNew { + return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn) + } + + //nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE. + switch ecn { + case protocol.ECT0: + h.ect0++ + case protocol.ECT1: + h.ect1++ + case protocol.ECNCE: + h.ecnce++ + } + if !ackEliciting { + return nil + } + h.hasNewAck = true + return nil +} + +func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame { + if !h.hasNewAck { + return nil + } + + // This function always returns the same ACK frame struct, filled with the most recent values. + ack := h.lastAck + if ack == nil { + ack = &wire.AckFrame{} + } + ack.Reset() + ack.ECT0 = h.ect0 + ack.ECT1 = h.ect1 + ack.ECNCE = h.ecnce + ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges) + + h.lastAck = ack + h.hasNewAck = false + return ack +} + +func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool { + return h.packetHistory.IsPotentiallyDuplicate(pn) +} + +// number of ack-eliciting packets received before sending an ACK const packetsBeforeAck = 2 -type receivedPacketTracker struct { - largestObserved protocol.PacketNumber - ignoreBelow protocol.PacketNumber - largestObservedReceivedTime time.Time - ect0, ect1, ecnce uint64 +// The appDataReceivedPacketTracker tracks packets received in the Application Data packet number space. +// It waits until at least 2 packets were received before queueing an ACK, or until the max_ack_delay was reached. +type appDataReceivedPacketTracker struct { + receivedPacketTracker - packetHistory *receivedPacketHistory + largestObservedRcvdTime time.Time - maxAckDelay time.Duration - rttStats *utils.RTTStats + largestObserved protocol.PacketNumber + ignoreBelow protocol.PacketNumber - hasNewAck bool // true as soon as we received an ack-eliciting new packet - ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets + maxAckDelay time.Duration + ackQueued bool // true if we need send a new ACK ackElicitingPacketsReceivedSinceLastAck int ackAlarm time.Time - lastAck *wire.AckFrame logger utils.Logger } -func newReceivedPacketTracker( - rttStats *utils.RTTStats, - logger utils.Logger, -) *receivedPacketTracker { - return &receivedPacketTracker{ - packetHistory: newReceivedPacketHistory(), - maxAckDelay: protocol.MaxAckDelay, - rttStats: rttStats, - logger: logger, +func newAppDataReceivedPacketTracker(logger utils.Logger) *appDataReceivedPacketTracker { + h := &appDataReceivedPacketTracker{ + receivedPacketTracker: *newReceivedPacketTracker(), + maxAckDelay: protocol.MaxAckDelay, + logger: logger, } + return h } -func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) error { - if isNew := h.packetHistory.ReceivedPacket(packetNumber); !isNew { - return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", packetNumber) +func (h *appDataReceivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error { + if err := h.receivedPacketTracker.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil { + return err } - - isMissing := h.isMissing(packetNumber) - if packetNumber >= h.largestObserved { - h.largestObserved = packetNumber - h.largestObservedReceivedTime = rcvTime + if pn >= h.largestObserved { + h.largestObserved = pn + h.largestObservedRcvdTime = rcvTime } - - if shouldInstigateAck { - h.hasNewAck = true + if !ackEliciting { + return nil } - if shouldInstigateAck { - h.maybeQueueAck(packetNumber, rcvTime, isMissing) + h.ackElicitingPacketsReceivedSinceLastAck++ + isMissing := h.isMissing(pn) + if !h.ackQueued && h.shouldQueueACK(pn, ecn, isMissing) { + h.ackQueued = true + h.ackAlarm = time.Time{} // cancel the ack alarm } - switch ecn { - case protocol.ECNNon: - case protocol.ECT0: - h.ect0++ - case protocol.ECT1: - h.ect1++ - case protocol.ECNCE: - h.ecnce++ + if !h.ackQueued { + // No ACK queued, but we'll need to acknowledge the packet after max_ack_delay. + h.ackAlarm = rcvTime.Add(h.maxAckDelay) + if h.logger.Debug() { + h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay) + } } return nil } // IgnoreBelow sets a lower limit for acknowledging packets. // Packets with packet numbers smaller than p will not be acked. -func (h *receivedPacketTracker) IgnoreBelow(p protocol.PacketNumber) { - if p <= h.ignoreBelow { +func (h *appDataReceivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) { + if pn <= h.ignoreBelow { return } - h.ignoreBelow = p - h.packetHistory.DeleteBelow(p) + h.ignoreBelow = pn + h.packetHistory.DeleteBelow(pn) if h.logger.Debug() { - h.logger.Debugf("\tIgnoring all packets below %d.", p) + h.logger.Debugf("\tIgnoring all packets below %d.", pn) } } // isMissing says if a packet was reported missing in the last ACK. -func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool { +func (h *appDataReceivedPacketTracker) isMissing(p protocol.PacketNumber) bool { if h.lastAck == nil || p < h.ignoreBelow { return false } return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p) } -func (h *receivedPacketTracker) hasNewMissingPackets() bool { +func (h *appDataReceivedPacketTracker) hasNewMissingPackets() bool { if h.lastAck == nil { return false } @@ -103,31 +157,21 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool { return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1 } -// maybeQueueAck queues an ACK, if necessary. -func (h *receivedPacketTracker) maybeQueueAck(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) { +func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool { // always acknowledge the first packet if h.lastAck == nil { - if !h.ackQueued { - h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") - } - h.ackQueued = true - return + h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") + return true } - if h.ackQueued { - return - } - - h.ackElicitingPacketsReceivedSinceLastAck++ - // Send an ACK if this packet was reported missing in an ACK sent before. // Ack decimation with reordering relies on the timer to send an ACK, but if - // missing packets we reported in the previous ack, send an ACK immediately. + // missing packets we reported in the previous ACK, send an ACK immediately. if wasMissing { if h.logger.Debug() { h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn) } - h.ackQueued = true + return true } // send an ACK every 2 ack-eliciting packets @@ -135,62 +179,42 @@ func (h *receivedPacketTracker) maybeQueueAck(pn protocol.PacketNumber, rcvTime if h.logger.Debug() { h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck) } - h.ackQueued = true - } else if h.ackAlarm.IsZero() { - if h.logger.Debug() { - h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay) - } - h.ackAlarm = rcvTime.Add(h.maxAckDelay) + return true } - // Queue an ACK if there are new missing packets to report. + // queue an ACK if there are new missing packets to report if h.hasNewMissingPackets() { h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.") - h.ackQueued = true + return true } - if h.ackQueued { - // cancel the ack alarm - h.ackAlarm = time.Time{} + // queue an ACK if the packet was ECN-CE marked + if ecn == protocol.ECNCE { + h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.") + return true } + return false } -func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { - if !h.hasNewAck { - return nil - } +func (h *appDataReceivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { now := time.Now() - if onlyIfQueued { - if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) { + if onlyIfQueued && !h.ackQueued { + if h.ackAlarm.IsZero() || h.ackAlarm.After(now) { return nil } - if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() { + if h.logger.Debug() && !h.ackAlarm.IsZero() { h.logger.Debugf("Sending ACK because the ACK timer expired.") } } - - // This function always returns the same ACK frame struct, filled with the most recent values. - ack := h.lastAck + ack := h.receivedPacketTracker.GetAckFrame() if ack == nil { - ack = &wire.AckFrame{} + return nil } - ack.Reset() - ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedReceivedTime)) - ack.ECT0 = h.ect0 - ack.ECT1 = h.ect1 - ack.ECNCE = h.ecnce - ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges) - - h.lastAck = ack - h.ackAlarm = time.Time{} + ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime)) h.ackQueued = false - h.hasNewAck = false + h.ackAlarm = time.Time{} h.ackElicitingPacketsReceivedSinceLastAck = 0 return ack } -func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm } - -func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool { - return h.packetHistory.IsPotentiallyDuplicate(pn) -} +func (h *appDataReceivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm } diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go index c955da6e9..3cef89239 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go @@ -92,9 +92,12 @@ type sentPacketHandler struct { // The alarm timeout alarm time.Time + enableECN bool + ecnTracker ecnHandler + perspective protocol.Perspective - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger } @@ -110,8 +113,9 @@ func newSentPacketHandler( initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, clientAddressValidated bool, + enableECN bool, pers protocol.Perspective, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, ) *sentPacketHandler { congestion := congestion.NewCubicSender( @@ -122,7 +126,7 @@ func newSentPacketHandler( tracer, ) - return &sentPacketHandler{ + h := &sentPacketHandler{ peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated, initialPackets: newPacketNumberSpace(initialPN, false), @@ -134,6 +138,11 @@ func newSentPacketHandler( tracer: tracer, logger: logger, } + if enableECN { + h.enableECN = true + h.ecnTracker = newECNTracker(logger, tracer) + } + return h } func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { @@ -187,7 +196,7 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { default: panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) } - if h.tracer != nil && h.ptoCount != 0 { + if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 { h.tracer.UpdatedPTOCount(0) } h.ptoCount = 0 @@ -228,6 +237,7 @@ func (h *sentPacketHandler) SentPacket( streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, + ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool, ) { @@ -235,7 +245,7 @@ func (h *sentPacketHandler) SentPacket( pnSpace := h.getPacketNumberSpace(encLevel) if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() { - for p := utils.Max(0, pnSpace.largestSent+1); p < pn; p++ { + for p := max(0, pnSpace.largestSent+1); p < pn; p++ { h.logger.Debugf("Skipping packet number %d", p) } } @@ -252,6 +262,10 @@ func (h *sentPacketHandler) SentPacket( } h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting) + if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil { + h.ecnTracker.SentPacket(pn, ecn) + } + if !isAckEliciting { pnSpace.history.SentNonAckElicitingPacket(pn) if !h.peerCompletedAddressValidation { @@ -272,7 +286,7 @@ func (h *sentPacketHandler) SentPacket( p.includedInBytesInFlight = true pnSpace.history.SentAckElicitingPacket(p) - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } h.setLossDetectionTimer() @@ -302,8 +316,6 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } - pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked) - // Servers complete address validation when a protected packet is received. if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation && (encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) { @@ -324,7 +336,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En // don't use the ack delay for Initial and Handshake packets var ackDelay time.Duration if encLevel == protocol.Encryption1RTT { - ackDelay = utils.Min(ack.DelayTime, h.rttStats.MaxAckDelay()) + ackDelay = min(ack.DelayTime, h.rttStats.MaxAckDelay()) } h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime) if h.logger.Debug() { @@ -333,6 +345,17 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.congestion.MaybeExitSlowStart() } } + + // Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked. + if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked { + congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE)) + if congested { + h.congestion.OnCongestionEvent(largestAcked, 0, priorInFlight) + } + } + + pnSpace.largestAcked = max(pnSpace.largestAcked, largestAcked) + if err := h.detectLostPackets(rcvTime, encLevel); err != nil { return false, err } @@ -353,14 +376,14 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En // Reset the pto_count unless the client is unsure if the server has validated the client's address. if h.peerCompletedAddressValidation { - if h.tracer != nil && h.ptoCount != 0 { + if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 { h.tracer.UpdatedPTOCount(0) } h.ptoCount = 0 } h.numProbesToSend = 0 - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } @@ -423,7 +446,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL for _, p := range h.ackedPackets { if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { - h.lowestNotConfirmedAcked = utils.Max(h.lowestNotConfirmedAcked, p.LargestAcked+1) + h.lowestNotConfirmedAcked = max(h.lowestNotConfirmedAcked, p.LargestAcked+1) } for _, f := range p.Frames { @@ -439,7 +462,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL if err := pnSpace.history.Remove(p.PacketNumber); err != nil { return nil, err } - if h.tracer != nil { + if h.tracer != nil && h.tracer.AcknowledgedPacket != nil { h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber) } } @@ -532,7 +555,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { if !lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = lossTime - if h.tracer != nil && h.alarm != oldAlarm { + if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm { h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm) } return @@ -543,7 +566,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { h.alarm = time.Time{} if !oldAlarm.IsZero() { h.logger.Debugf("Canceling loss detection timer. Amplification limited.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } @@ -555,7 +578,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { h.alarm = time.Time{} if !oldAlarm.IsZero() { h.logger.Debugf("Canceling loss detection timer. No packets in flight.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } @@ -568,14 +591,14 @@ func (h *sentPacketHandler) setLossDetectionTimer() { if !oldAlarm.IsZero() { h.alarm = time.Time{} h.logger.Debugf("Canceling loss detection timer. No PTO needed..") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } return } h.alarm = ptoTime - if h.tracer != nil && h.alarm != oldAlarm { + if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm { h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm) } } @@ -584,11 +607,11 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.lossTime = time.Time{} - maxRTT := float64(utils.Max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) + maxRTT := float64(max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) lossDelay := time.Duration(timeThreshold * maxRTT) // Minimum time of granularity before packets are deemed lost. - lossDelay = utils.Max(lossDelay, protocol.TimerGranularity) + lossDelay = max(lossDelay, protocol.TimerGranularity) // Packets sent before this time are deemed lost. lostSendTime := now.Add(-lossDelay) @@ -606,7 +629,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LostPacket != nil { h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) } } @@ -616,7 +639,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LostPacket != nil { h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) } } @@ -635,7 +658,10 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E h.removeFromBytesInFlight(p) h.queueFramesForRetransmission(p) if !p.IsPathMTUProbePacket { - h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + h.congestion.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight) + } + if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil { + h.ecnTracker.LostPacket(p.PacketNumber) } } } @@ -650,7 +676,7 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { if h.logger.Debug() { h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerExpired != nil { h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) } // Early retransmit or time loss detection @@ -687,8 +713,12 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) } if h.tracer != nil { - h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) - h.tracer.UpdatedPTOCount(h.ptoCount) + if h.tracer.LossTimerExpired != nil { + h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) + } + if h.tracer.UpdatedPTOCount != nil { + h.tracer.UpdatedPTOCount(h.ptoCount) + } } h.numProbesToSend += 2 //nolint:exhaustive // We never arm a PTO timer for 0-RTT packets. @@ -712,6 +742,16 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { return h.alarm } +func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN { + if !h.enableECN { + return protocol.ECNUnsupported + } + if !isShortHeaderPacket { + return protocol.ECNNon + } + return h.ecnTracker.Mode() +} + func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { pnSpace := h.getPacketNumberSpace(encLevel) pn := pnSpace.pns.Peek() @@ -824,7 +864,7 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) { p.Frames = nil } -func (h *sentPacketHandler) ResetForRetry() error { +func (h *sentPacketHandler) ResetForRetry(now time.Time) error { h.bytesInFlight = 0 var firstPacketSendTime time.Time h.initialPackets.history.Iterate(func(p *packet) (bool, error) { @@ -850,12 +890,11 @@ func (h *sentPacketHandler) ResetForRetry() error { // Otherwise, we don't know which Initial the Retry was sent in response to. if h.ptoCount == 0 { // Don't set the RTT to a value lower than 5ms here. - now := time.Now() - h.rttStats.UpdateRTT(utils.Max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) + h.rttStats.UpdateRTT(max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) if h.logger.Debug() { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } } @@ -864,8 +903,10 @@ func (h *sentPacketHandler) ResetForRetry() error { oldAlarm := h.alarm h.alarm = time.Time{} if h.tracer != nil { - h.tracer.UpdatedPTOCount(0) - if !oldAlarm.IsZero() { + if h.tracer.UpdatedPTOCount != nil { + h.tracer.UpdatedPTOCount(0) + } + if !oldAlarm.IsZero() && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } diff --git a/vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go b/vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go index a73cf82aa..4e30de650 100644 --- a/vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go @@ -5,7 +5,6 @@ import ( "time" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" ) // This cubic implementation is based on the one found in Chromiums's QUIC @@ -187,7 +186,7 @@ func (c *Cubic) CongestionWindowAfterAck( targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow } // Limit the CWND increase to half the acked bytes. - targetCongestionWindow = utils.Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) + targetCongestionWindow = min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) // Increase the window by approximately Alpha * 1 MSS of bytes every // time we ack an estimated tcp window of bytes. For small diff --git a/vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go b/vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go index 2e5084751..a1b06ab34 100644 --- a/vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go @@ -56,7 +56,7 @@ type cubicSender struct { maxDatagramSize protocol.ByteCount lastState logging.CongestionState - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer } var ( @@ -70,7 +70,7 @@ func NewCubicSender( rttStats *utils.RTTStats, initialMaxDatagramSize protocol.ByteCount, reno bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, ) *cubicSender { return newCubicSender( clock, @@ -90,7 +90,7 @@ func newCubicSender( initialMaxDatagramSize, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, ) *cubicSender { c := &cubicSender{ rttStats: rttStats, @@ -108,7 +108,7 @@ func newCubicSender( maxDatagramSize: initialMaxDatagramSize, } c.pacer = newPacer(c.BandwidthEstimate) - if c.tracer != nil { + if c.tracer != nil && c.tracer.UpdatedCongestionState != nil { c.lastState = logging.CongestionStateSlowStart c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) } @@ -178,7 +178,7 @@ func (c *cubicSender) OnPacketAcked( priorInFlight protocol.ByteCount, eventTime time.Time, ) { - c.largestAckedPacketNumber = utils.Max(ackedPacketNumber, c.largestAckedPacketNumber) + c.largestAckedPacketNumber = max(ackedPacketNumber, c.largestAckedPacketNumber) if c.InRecovery() { return } @@ -188,7 +188,7 @@ func (c *cubicSender) OnPacketAcked( } } -func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { +func (c *cubicSender) OnCongestionEvent(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets // already sent should be treated as a single loss event, since it's expected. if packetNumber <= c.largestSentAtLastCutback { @@ -246,7 +246,7 @@ func (c *cubicSender) maybeIncreaseCwnd( c.numAckedPackets = 0 } } else { - c.congestionWindow = utils.Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) + c.congestionWindow = min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) } } @@ -296,7 +296,7 @@ func (c *cubicSender) OnConnectionMigration() { } func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { - if c.tracer == nil || new == c.lastState { + if c.tracer == nil || c.tracer.UpdatedCongestionState == nil || new == c.lastState { return } c.tracer.UpdatedCongestionState(new) diff --git a/vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go b/vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go index b2f7c908e..9679d9e46 100644 --- a/vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go @@ -4,7 +4,6 @@ import ( "time" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" ) // Note(pwestin): the magic clamping numbers come from the original code in @@ -75,8 +74,8 @@ func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT ti // Divide minRTT by 8 to get a rtt increase threshold for exiting. minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp) // Ensure the rtt threshold is never less than 2ms or more than 16ms. - minRTTincreaseThresholdUs = utils.Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) - minRTTincreaseThreshold := time.Duration(utils.Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond + minRTTincreaseThresholdUs = min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) + minRTTincreaseThreshold := time.Duration(max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) { s.hystartFound = true diff --git a/vendor/github.com/quic-go/quic-go/internal/congestion/interface.go b/vendor/github.com/quic-go/quic-go/internal/congestion/interface.go index 484bd5f81..881f453b6 100644 --- a/vendor/github.com/quic-go/quic-go/internal/congestion/interface.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/interface.go @@ -14,7 +14,7 @@ type SendAlgorithm interface { CanSend(bytesInFlight protocol.ByteCount) bool MaybeExitSlowStart() OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time) - OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) + OnCongestionEvent(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) OnRetransmissionTimeout(packetsRetransmitted bool) SetMaxDatagramSize(protocol.ByteCount) } diff --git a/vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go b/vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go index 09ea26806..34d3d1d09 100644 --- a/vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go @@ -1,11 +1,9 @@ package congestion import ( - "math" "time" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" ) const maxBurstSizePackets = 10 @@ -26,7 +24,7 @@ func newPacer(getBandwidth func() Bandwidth) *pacer { bw := uint64(getBandwidth() / BytesPerSecond) // Use a slightly higher value than the actual measured bandwidth. // RTT variations then won't result in under-utilization of the congestion window. - // Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire, + // Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire, // provided the congestion window is fully utilized and acknowledgments arrive at regular intervals. return bw * 5 / 4 }, @@ -37,7 +35,7 @@ func newPacer(getBandwidth func() Bandwidth) *pacer { func (p *pacer) SentPacket(sendTime time.Time, size protocol.ByteCount) { budget := p.Budget(sendTime) - if size > budget { + if size >= budget { p.budgetAtLastSent = 0 } else { p.budgetAtLastSent = budget - size @@ -53,11 +51,11 @@ func (p *pacer) Budget(now time.Time) protocol.ByteCount { if budget < 0 { // protect against overflows budget = protocol.MaxByteCount } - return utils.Min(p.maxBurstSize(), budget) + return min(p.maxBurstSize(), budget) } func (p *pacer) maxBurstSize() protocol.ByteCount { - return utils.Max( + return max( protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.adjustedBandwidth())/1e9, maxBurstSizePackets*p.maxDatagramSize, ) @@ -69,10 +67,16 @@ func (p *pacer) TimeUntilSend() time.Time { if p.budgetAtLastSent >= p.maxDatagramSize { return time.Time{} } - return p.lastSentTime.Add(utils.Max( - protocol.MinPacingDelay, - time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.adjustedBandwidth())))*time.Nanosecond, - )) + diff := 1e9 * uint64(p.maxDatagramSize-p.budgetAtLastSent) + bw := p.adjustedBandwidth() + // We might need to round up this value. + // Otherwise, we might have a budget (slightly) smaller than the datagram size when the timer expires. + d := diff / bw + // this is effectively a math.Ceil, but using only integer math + if diff%bw > 0 { + d++ + } + return p.lastSentTime.Add(max(protocol.MinPacingDelay, time.Duration(d)*time.Nanosecond)) } func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) { diff --git a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go index f3f24a60e..3d88d577e 100644 --- a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go +++ b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go @@ -48,10 +48,12 @@ func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { } // UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame. -func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { +func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (updated bool) { if offset > c.sendWindow { c.sendWindow = offset + return true } + return false } func (c *baseFlowController) sendWindowSize() protocol.ByteCount { @@ -107,7 +109,7 @@ func (c *baseFlowController) maybeAdjustWindowSize() { now := time.Now() if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) { // window is consumed too fast, try to increase the window size - newSize := utils.Min(2*c.receiveWindowSize, c.maxReceiveWindowSize) + newSize := min(2*c.receiveWindowSize, c.maxReceiveWindowSize) if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) { c.receiveWindowSize = newSize } diff --git a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go index 13e69d6c4..8504cdcf5 100644 --- a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go +++ b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go @@ -87,7 +87,7 @@ func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCoun c.mutex.Lock() if inc > c.receiveWindowSize { c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10)) - newSize := utils.Min(inc, c.maxReceiveWindowSize) + newSize := min(inc, c.maxReceiveWindowSize) if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) { c.receiveWindowSize = newSize } diff --git a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go index 946519d52..fc5f9de0a 100644 --- a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go +++ b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go @@ -5,7 +5,7 @@ import "github.com/quic-go/quic-go/internal/protocol" type flowController interface { // for sending SendWindowSize() protocol.ByteCount - UpdateSendWindow(protocol.ByteCount) + UpdateSendWindow(protocol.ByteCount) (updated bool) AddBytesSent(protocol.ByteCount) // for receiving AddBytesRead(protocol.ByteCount) @@ -16,12 +16,11 @@ type flowController interface { // A StreamFlowController is a flow controller for a QUIC stream. type StreamFlowController interface { flowController - // for receiving - // UpdateHighestReceived should be called when a new highest offset is received + // UpdateHighestReceived is called when a new highest offset is received // final has to be to true if this is the final offset of the stream, // as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame UpdateHighestReceived(offset protocol.ByteCount, final bool) error - // Abandon should be called when reading from the stream is aborted early, + // Abandon is called when reading from the stream is aborted early, // and there won't be any further calls to AddBytesRead. Abandon() } diff --git a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go index 1770a9c84..1a69fb2b3 100644 --- a/vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go +++ b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go @@ -123,7 +123,7 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { } func (c *streamFlowController) SendWindowSize() protocol.ByteCount { - return utils.Min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize()) + return min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize()) } func (c *streamFlowController) shouldQueueWindowUpdate() bool { diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/aead.go b/vendor/github.com/quic-go/quic-go/internal/handshake/aead.go index 6aa89fb3f..1baf5d6b0 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/aead.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/aead.go @@ -1,14 +1,12 @@ package handshake import ( - "crypto/cipher" "encoding/binary" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" ) -func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD { +func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.Version) *xorNonceAEAD { keyLabel := hkdfLabelKeyV1 ivLabel := hkdfLabelIVV1 if v == protocol.Version2 { @@ -21,28 +19,26 @@ func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumb } type longHeaderSealer struct { - aead cipher.AEAD + aead *xorNonceAEAD headerProtector headerProtector - - // use a single slice to avoid allocations - nonceBuf []byte + nonceBuf [8]byte } var _ LongHeaderSealer = &longHeaderSealer{} -func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer { +func newLongHeaderSealer(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderSealer { + if aead.NonceSize() != 8 { + panic("unexpected nonce size") + } return &longHeaderSealer{ aead: aead, headerProtector: headerProtector, - nonceBuf: make([]byte, aead.NonceSize()), } } func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { - binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - return s.aead.Seal(dst, s.nonceBuf, src, ad) + binary.BigEndian.PutUint64(s.nonceBuf[:], uint64(pn)) + return s.aead.Seal(dst, s.nonceBuf[:], src, ad) } func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { @@ -54,21 +50,23 @@ func (s *longHeaderSealer) Overhead() int { } type longHeaderOpener struct { - aead cipher.AEAD + aead *xorNonceAEAD headerProtector headerProtector highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) - // use a single slice to avoid allocations - nonceBuf []byte + // use a single array to avoid allocations + nonceBuf [8]byte } var _ LongHeaderOpener = &longHeaderOpener{} -func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener { +func newLongHeaderOpener(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderOpener { + if aead.NonceSize() != 8 { + panic("unexpected nonce size") + } return &longHeaderOpener{ aead: aead, headerProtector: headerProtector, - nonceBuf: make([]byte, aead.NonceSize()), } } @@ -77,12 +75,10 @@ func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wire } func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { - binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - dec, err := o.aead.Open(dst, o.nonceBuf, src, ad) + binary.BigEndian.PutUint64(o.nonceBuf[:], uint64(pn)) + dec, err := o.aead.Open(dst, o.nonceBuf[:], src, ad) if err == nil { - o.highestRcvdPN = utils.Max(o.highestRcvdPN, pn) + o.highestRcvdPN = max(o.highestRcvdPN, pn) } else { err = ErrDecryptionFailed } diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/cipher_suite.go b/vendor/github.com/quic-go/quic-go/internal/handshake/cipher_suite.go index 265231f0c..d8a381daf 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/cipher_suite.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/cipher_suite.go @@ -18,7 +18,7 @@ type cipherSuite struct { ID uint16 Hash crypto.Hash KeyLen int - AEAD func(key, nonceMask []byte) cipher.AEAD + AEAD func(key, nonceMask []byte) *xorNonceAEAD } func (s cipherSuite) IVLen() int { return aeadNonceLength } @@ -36,7 +36,7 @@ func getCipherSuite(id uint16) *cipherSuite { } } -func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD { +func aeadAESGCMTLS13(key, nonceMask []byte) *xorNonceAEAD { if len(nonceMask) != aeadNonceLength { panic("tls: internal error: wrong nonce length") } @@ -54,7 +54,7 @@ func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD { return ret } -func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD { +func aeadChaCha20Poly1305(key, nonceMask []byte) *xorNonceAEAD { if len(nonceMask) != aeadNonceLength { panic("tls: internal error: wrong nonce length") } diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go b/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go index 332dbd7ef..adf74fe74 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go @@ -8,7 +8,6 @@ import ( "fmt" "net" "strings" - "sync" "sync/atomic" "time" @@ -25,15 +24,15 @@ type quicVersionContextKey struct{} var QUICVersionContextKey = &quicVersionContextKey{} -const clientSessionStateRevision = 3 +const clientSessionStateRevision = 4 type cryptoSetup struct { tlsConf *tls.Config - conn *qtls.QUICConn + conn *tls.QUICConn events []Event - version protocol.VersionNumber + version protocol.Version ourParams *wire.TransportParameters peerParams *wire.TransportParameters @@ -43,13 +42,11 @@ type cryptoSetup struct { rttStats *utils.RTTStats - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger perspective protocol.Perspective - mutex sync.Mutex // protects all members below - handshakeCompleteTime time.Time zeroRTTOpener LongHeaderOpener // only set for the server @@ -77,9 +74,9 @@ func NewCryptoSetupClient( tlsConf *tls.Config, enable0RTT bool, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, - version protocol.VersionNumber, + version protocol.Version, ) CryptoSetup { cs := newCryptoSetup( connID, @@ -93,11 +90,12 @@ func NewCryptoSetupClient( tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} + quicConf := &tls.QUICConfig{TLSConfig: tlsConf} qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState) cs.tlsConf = tlsConf + cs.allow0RTT = enable0RTT - cs.conn = qtls.QUICClient(quicConf) + cs.conn = tls.QUICClient(quicConf) cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) return cs @@ -111,9 +109,9 @@ func NewCryptoSetupServer( tlsConf *tls.Config, allow0RTT bool, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, - version protocol.VersionNumber, + version protocol.Version, ) CryptoSetup { cs := newCryptoSetup( connID, @@ -126,12 +124,12 @@ func NewCryptoSetupServer( ) cs.allow0RTT = allow0RTT - quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} - qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT) + quicConf := &tls.QUICConfig{TLSConfig: tlsConf} + qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket) addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr) cs.tlsConf = quicConf.TLSConfig - cs.conn = qtls.QUICServer(quicConf) + cs.conn = tls.QUICServer(quicConf) return cs } @@ -146,6 +144,9 @@ func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} c, err := gcfc(info) if c != nil { + c = c.Clone() + // This won't be necessary anymore once https://github.com/golang/go/issues/63722 is accepted. + c.MinVersion = tls.VersionTLS13 // We're returning a tls.Config here, so we need to apply this recursively. addConnToClientHelloInfo(c, localAddr, remoteAddr) } @@ -165,13 +166,13 @@ func newCryptoSetup( connID protocol.ConnectionID, tp *wire.TransportParameters, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, perspective protocol.Perspective, - version protocol.VersionNumber, + version protocol.Version, ) *cryptoSetup { initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) - if tracer != nil { + if tracer != nil && tracer.UpdatedKeyFromTLS != nil { tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } @@ -193,7 +194,7 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version) h.initialSealer = initialSealer h.initialOpener = initialOpener - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } @@ -260,28 +261,28 @@ func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLev } } -func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) { +func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) { switch ev.Kind { - case qtls.QUICNoEvent: + case tls.QUICNoEvent: return true, nil - case qtls.QUICSetReadSecret: - h.SetReadKey(ev.Level, ev.Suite, ev.Data) + case tls.QUICSetReadSecret: + h.setReadKey(ev.Level, ev.Suite, ev.Data) return false, nil - case qtls.QUICSetWriteSecret: - h.SetWriteKey(ev.Level, ev.Suite, ev.Data) + case tls.QUICSetWriteSecret: + h.setWriteKey(ev.Level, ev.Suite, ev.Data) return false, nil - case qtls.QUICTransportParameters: + case tls.QUICTransportParameters: return false, h.handleTransportParameters(ev.Data) - case qtls.QUICTransportParametersRequired: + case tls.QUICTransportParametersRequired: h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective)) return false, nil - case qtls.QUICRejectedEarlyData: + case tls.QUICRejectedEarlyData: h.rejected0RTT() return false, nil - case qtls.QUICWriteData: - h.WriteRecord(ev.Level, ev.Data) + case tls.QUICWriteData: + h.writeRecord(ev.Level, ev.Data) return false, nil - case qtls.QUICHandshakeDone: + case tls.QUICHandshakeDone: h.handshakeComplete() return false, nil default: @@ -309,55 +310,75 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) error { } // must be called after receiving the transport parameters -func (h *cryptoSetup) marshalDataForSessionState() []byte { +func (h *cryptoSetup) marshalDataForSessionState(earlyData bool) []byte { b := make([]byte, 0, 256) b = quicvarint.Append(b, clientSessionStateRevision) b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds())) - return h.peerParams.MarshalForSessionTicket(b) + if earlyData { + // only save the transport parameters for 0-RTT enabled session tickets + return h.peerParams.MarshalForSessionTicket(b) + } + return b } -func (h *cryptoSetup) handleDataFromSessionState(data []byte) { - tp, err := h.handleDataFromSessionStateImpl(data) +func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (allowEarlyData bool) { + rtt, tp, err := decodeDataFromSessionState(data, earlyData) if err != nil { h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) return } - h.zeroRTTParameters = tp + h.rttStats.SetInitialRTT(rtt) + // The session ticket might have been saved from a connection that allowed 0-RTT, + // and therefore contain transport parameters. + // Only use them if 0-RTT is actually used on the new connection. + if tp != nil && h.allow0RTT { + h.zeroRTTParameters = tp + return true + } + return false } -func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) { +func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) { r := bytes.NewReader(data) ver, err := quicvarint.Read(r) if err != nil { - return nil, err + return 0, nil, err } if ver != clientSessionStateRevision { - return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) + return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) } - rtt, err := quicvarint.Read(r) + rttEncoded, err := quicvarint.Read(r) if err != nil { - return nil, err + return 0, nil, err + } + rtt := time.Duration(rttEncoded) * time.Microsecond + if !earlyData { + return rtt, nil, nil } - h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) var tp wire.TransportParameters if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return nil, err + return 0, nil, err } - return &tp, nil + return rtt, &tp, nil } func (h *cryptoSetup) getDataForSessionTicket() []byte { - return (&sessionTicket{ - Parameters: h.ourParams, - RTT: h.rttStats.SmoothedRTT(), - }).Marshal() + ticket := &sessionTicket{ + RTT: h.rttStats.SmoothedRTT(), + } + if h.allow0RTT { + ticket.Parameters = h.ourParams + } + return ticket.Marshal() } // GetSessionTicket generates a new session ticket. // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. // It is only valid for the server. func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { - if err := qtls.SendSessionTicket(h.conn, h.allow0RTT); err != nil { + if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{ + EarlyData: h.allow0RTT, + }); err != nil { // Session tickets might be disabled by tls.Config.SessionTicketsDisabled. // We can't check h.tlsConfig here, since the actual config might have been obtained from // the GetConfigForClient callback. @@ -369,22 +390,28 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { return nil, err } ev := h.conn.NextEvent() - if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication { + if ev.Kind != tls.QUICWriteData || ev.Level != tls.QUICEncryptionLevelApplication { panic("crypto/tls bug: where's my session ticket?") } ticket := ev.Data - if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent { + if ev := h.conn.NextEvent(); ev.Kind != tls.QUICNoEvent { panic("crypto/tls bug: why more than one ticket?") } return ticket, nil } -// accept0RTT is called for the server when receiving the client's session ticket. -// It decides whether to accept 0-RTT. -func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { +// handleSessionTicket is called for the server when receiving the client's session ticket. +// It reads parameters from the session ticket and checks whether to accept 0-RTT if the session ticket enabled 0-RTT. +// Note that the fact that the session ticket allows 0-RTT doesn't mean that the actual TLS handshake enables 0-RTT: +// A client may use a 0-RTT enabled session to resume a TLS session without using 0-RTT. +func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool { var t sessionTicket - if err := t.Unmarshal(sessionTicketData); err != nil { - h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error()) + if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil { + h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error()) + return false + } + h.rttStats.SetInitialRTT(t.RTT) + if !using0RTT { return false } valid := h.ourParams.ValidFor0RTT(t.Parameters) @@ -397,7 +424,6 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { return false } h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) - h.rttStats.SetInitialRTT(t.RTT) return true } @@ -405,22 +431,19 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { func (h *cryptoSetup) rejected0RTT() { h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") - h.mutex.Lock() had0RTTKeys := h.zeroRTTSealer != nil h.zeroRTTSealer = nil - h.mutex.Unlock() if had0RTTKeys { h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys}) } } -func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { +func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) - h.mutex.Lock() //nolint:exhaustive // The TLS stack doesn't export Initial keys. switch el { - case qtls.QUICEncryptionLevelEarly: + case tls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveClient { panic("Received 0-RTT read key for the client") } @@ -432,7 +455,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.handshakeOpener = newLongHeaderOpener( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -440,7 +463,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr if h.logger.Debug() { h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: h.aead.SetReadKey(suite, trafficSecret) h.has1RTTOpener = true if h.logger.Debug() { @@ -449,19 +472,17 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr default: panic("unexpected read encryption level") } - h.mutex.Unlock() h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) } } -func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { +func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) - h.mutex.Lock() //nolint:exhaustive // The TLS stack doesn't export Initial keys. switch el { - case qtls.QUICEncryptionLevelEarly: + case tls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveServer { panic("Received 0-RTT write key for the server") } @@ -469,16 +490,15 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), ) - h.mutex.Unlock() if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) } // don't set used0RTT here. 0-RTT might still get rejected. return - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.handshakeSealer = newLongHeaderSealer( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -486,7 +506,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t if h.logger.Debug() { h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: h.aead.SetWriteKey(suite, trafficSecret) h.has1RTTSealer = true if h.logger.Debug() { @@ -497,28 +517,27 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t h.used0RTT.Store(true) h.zeroRTTSealer = nil h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil { h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) } } default: panic("unexpected write encryption level") } - h.mutex.Unlock() - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) } } -// WriteRecord is called when TLS writes data -func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) { +// writeRecord is called when TLS writes data +func (h *cryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) { //nolint:exhaustive // handshake records can only be written for Initial and Handshake. switch encLevel { - case qtls.QUICEncryptionLevelInitial: + case tls.QUICEncryptionLevelInitial: h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p}) - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p}) - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: panic("unexpected write") default: panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel)) @@ -526,11 +545,9 @@ func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) { } func (h *cryptoSetup) DiscardInitialKeys() { - h.mutex.Lock() dropped := h.initialOpener != nil h.initialOpener = nil h.initialSealer = nil - h.mutex.Unlock() if dropped { h.logger.Debugf("Dropping Initial keys.") } @@ -545,22 +562,17 @@ func (h *cryptoSetup) SetHandshakeConfirmed() { h.aead.SetHandshakeConfirmed() // drop Handshake keys var dropped bool - h.mutex.Lock() if h.handshakeOpener != nil { h.handshakeOpener = nil h.handshakeSealer = nil dropped = true } - h.mutex.Unlock() if dropped { h.logger.Debugf("Dropping Handshake keys.") } } func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.initialSealer == nil { return nil, ErrKeysDropped } @@ -568,9 +580,6 @@ func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { } func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTSealer == nil { return nil, ErrKeysDropped } @@ -578,9 +587,6 @@ func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { } func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.handshakeSealer == nil { if h.initialSealer == nil { return nil, ErrKeysDropped @@ -591,9 +597,6 @@ func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { } func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if !h.has1RTTSealer { return nil, ErrKeysNotYetAvailable } @@ -601,9 +604,6 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { } func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.initialOpener == nil { return nil, ErrKeysDropped } @@ -611,9 +611,6 @@ func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { } func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTOpener == nil { if h.initialOpener != nil { return nil, ErrKeysNotYetAvailable @@ -625,9 +622,6 @@ func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { } func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.handshakeOpener == nil { if h.initialOpener != nil { return nil, ErrKeysNotYetAvailable @@ -639,13 +633,10 @@ func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { } func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { h.zeroRTTOpener = nil h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil { h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) } } @@ -665,7 +656,7 @@ func (h *cryptoSetup) ConnectionState() ConnectionState { func wrapError(err error) error { // alert 80 is an internal error - if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 { + if alertErr := tls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 { return qerr.NewLocalCryptoError(uint8(alertErr), err) } return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()} diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go b/vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go index fb6092e04..2c5ee42f1 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go @@ -17,14 +17,14 @@ type headerProtector interface { DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) } -func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string { +func hkdfHeaderProtectionLabel(v protocol.Version) string { if v == protocol.Version2 { return "quicv2 hp" } return "quic hp" } -func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector { +func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.Version) headerProtector { hkdfLabel := hkdfHeaderProtectionLabel(v) switch suite.ID { case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: @@ -37,7 +37,7 @@ func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader b } type aesHeaderProtector struct { - mask []byte + mask [16]byte // AES always has a 16 byte block size block cipher.Block isLongHeader bool } @@ -52,7 +52,6 @@ func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeade } return &aesHeaderProtector{ block: block, - mask: make([]byte, block.BlockSize()), isLongHeader: isLongHeader, } } @@ -69,7 +68,7 @@ func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []by if len(sample) != len(p.mask) { panic("invalid sample size") } - p.block.Encrypt(p.mask, sample) + p.block.Encrypt(p.mask[:], sample) if p.isLongHeader { *firstByte ^= p.mask[0] & 0xf } else { diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go b/vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go index c4fd86c57..0caf1c8e5 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go @@ -7,7 +7,7 @@ import ( "golang.org/x/crypto/hkdf" ) -// hkdfExpandLabel HKDF expands a label. +// hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1. // Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the // hkdfExpandLabel in the standard library. func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte { diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go b/vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go index b0377c39a..b8aa7e3e7 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go @@ -21,7 +21,7 @@ const ( hkdfLabelIVV2 = "quicv2 iv" ) -func getSalt(v protocol.VersionNumber) []byte { +func getSalt(v protocol.Version) []byte { if v == protocol.Version2 { return quicSaltV2 } @@ -31,7 +31,7 @@ func getSalt(v protocol.VersionNumber) []byte { var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256) // NewInitialAEAD creates a new AEAD for Initial encryption / decryption. -func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) { +func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.Version) (LongHeaderSealer, LongHeaderOpener) { clientSecret, serverSecret := computeSecrets(connID, v) var mySecret, otherSecret []byte if pers == protocol.PerspectiveClient { @@ -51,14 +51,14 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v))) } -func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) { +func computeSecrets(connID protocol.ConnectionID, v protocol.Version) (clientSecret, serverSecret []byte) { initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v)) clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) return } -func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) { +func computeInitialKeyAndIV(secret []byte, v protocol.Version) (key, iv []byte) { keyLabel := hkdfLabelKeyV1 ivLabel := hkdfLabelIVV1 if v == protocol.Version2 { diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/retry.go b/vendor/github.com/quic-go/quic-go/internal/handshake/retry.go index 68fa53ed1..30643cdfc 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/retry.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/retry.go @@ -40,7 +40,7 @@ var ( ) // GetRetryIntegrityTag calculates the integrity tag on a Retry packet -func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte { +func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.Version) *[16]byte { retryMutex.Lock() defer retryMutex.Unlock() diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go b/vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go index d3efeb294..9481af563 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go @@ -10,7 +10,7 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) -const sessionTicketRevision = 3 +const sessionTicketRevision = 4 type sessionTicket struct { Parameters *wire.TransportParameters @@ -21,10 +21,13 @@ func (t *sessionTicket) Marshal() []byte { b := make([]byte, 0, 256) b = quicvarint.Append(b, sessionTicketRevision) b = quicvarint.Append(b, uint64(t.RTT.Microseconds())) + if t.Parameters == nil { + return b + } return t.Parameters.MarshalForSessionTicket(b) } -func (t *sessionTicket) Unmarshal(b []byte) error { +func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error { r := bytes.NewReader(b) rev, err := quicvarint.Read(r) if err != nil { @@ -37,11 +40,15 @@ func (t *sessionTicket) Unmarshal(b []byte) error { if err != nil { return errors.New("failed to read RTT") } - var tp wire.TransportParameters - if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) + if using0RTT { + var tp wire.TransportParameters + if err := tp.UnmarshalFromSessionTicket(r); err != nil { + return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) + } + t.Parameters = &tp + } else if r.Len() > 0 { + return fmt.Errorf("the session ticket has more bytes than expected") } - t.Parameters = &tp t.RTT = time.Duration(rtt) * time.Microsecond return nil } diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go b/vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go index e5e90bb3b..2d91e6b25 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/asn1" "fmt" - "io" "net" "time" @@ -45,15 +44,9 @@ type TokenGenerator struct { tokenProtector tokenProtector } -// NewTokenGenerator initializes a new TookenGenerator -func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) { - tokenProtector, err := newTokenProtector(rand) - if err != nil { - return nil, err - } - return &TokenGenerator{ - tokenProtector: tokenProtector, - }, nil +// NewTokenGenerator initializes a new TokenGenerator +func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator { + return &TokenGenerator{tokenProtector: newTokenProtector(key)} } // NewRetryToken generates a new token for a Retry for a given source address diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go b/vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go index 650f230b2..f3a99e411 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go @@ -3,6 +3,7 @@ package handshake import ( "crypto/aes" "crypto/cipher" + "crypto/rand" "crypto/sha256" "fmt" "io" @@ -10,6 +11,9 @@ import ( "golang.org/x/crypto/hkdf" ) +// TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens. +type TokenProtectorKey [32]byte + // TokenProtector is used to create and verify a token type tokenProtector interface { // NewToken creates a new token @@ -18,40 +22,29 @@ type tokenProtector interface { DecodeToken([]byte) ([]byte, error) } -const ( - tokenSecretSize = 32 - tokenNonceSize = 32 -) +const tokenNonceSize = 32 // tokenProtector is used to create and verify a token type tokenProtectorImpl struct { - rand io.Reader - secret []byte + key TokenProtectorKey } // newTokenProtector creates a source for source address tokens -func newTokenProtector(rand io.Reader) (tokenProtector, error) { - secret := make([]byte, tokenSecretSize) - if _, err := rand.Read(secret); err != nil { - return nil, err - } - return &tokenProtectorImpl{ - rand: rand, - secret: secret, - }, nil +func newTokenProtector(key TokenProtectorKey) tokenProtector { + return &tokenProtectorImpl{key: key} } // NewToken encodes data into a new token. func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { - nonce := make([]byte, tokenNonceSize) - if _, err := s.rand.Read(nonce); err != nil { + var nonce [tokenNonceSize]byte + if _, err := rand.Read(nonce[:]); err != nil { return nil, err } - aead, aeadNonce, err := s.createAEAD(nonce) + aead, aeadNonce, err := s.createAEAD(nonce[:]) if err != nil { return nil, err } - return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil + return append(nonce[:], aead.Seal(nil, aeadNonce, data, nil)...), nil } // DecodeToken decodes a token. @@ -68,7 +61,7 @@ func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) { } func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { - h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source")) + h := hkdf.New(sha256.New, s.key[:], nonce, []byte("quic-go token source")) key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 if _, err := io.ReadFull(h, key); err != nil { return nil, nil, err diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go b/vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go index 919b8a5bc..ceaa8047a 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go @@ -57,9 +57,9 @@ type updatableAEAD struct { rttStats *utils.RTTStats - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger - version protocol.VersionNumber + version protocol.Version // use a single slice to avoid allocations nonceBuf []byte @@ -70,7 +70,7 @@ var ( _ ShortHeaderSealer = &updatableAEAD{} ) -func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { +func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD { return &updatableAEAD{ firstPacketNumber: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, @@ -86,7 +86,7 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, func (a *updatableAEAD) rollKeys() { if a.prevRcvAEAD != nil { a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) - if a.tracer != nil { + if a.tracer != nil && a.tracer.DroppedKey != nil { a.tracer.DroppedKey(a.keyPhase - 1) } a.prevRcvAEADExpiry = time.Time{} @@ -133,7 +133,7 @@ func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) { // SetWriteKey sets the write key. // For the client, this function is called after SetReadKey. -// For the server, this function is called before SetWriteKey. +// For the server, this function is called before SetReadKey. func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) { a.sendAEAD = createAEAD(suite, trafficSecret, a.version) a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) @@ -172,7 +172,7 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac } } if err == nil { - a.highestRcvdPN = utils.Max(a.highestRcvdPN, pn) + a.highestRcvdPN = max(a.highestRcvdPN, pn) } return dec, err } @@ -182,7 +182,7 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac a.prevRcvAEAD = nil a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) a.prevRcvAEADExpiry = time.Time{} - if a.tracer != nil { + if a.tracer != nil && a.tracer.DroppedKey != nil { a.tracer.DroppedKey(a.keyPhase - 1) } } @@ -216,7 +216,7 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac // The peer initiated this key update. It's safe to drop the keys for the previous generation now. // Start a timer to drop the previous key generation. a.startKeyDropTimer(rcvTime) - if a.tracer != nil { + if a.tracer != nil && a.tracer.UpdatedKey != nil { a.tracer.UpdatedKey(a.keyPhase, true) } a.firstRcvdWithCurrentKey = pn @@ -308,7 +308,7 @@ func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { if a.shouldInitiateKeyUpdate() { a.rollKeys() a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) - if a.tracer != nil { + if a.tracer != nil && a.tracer.UpdatedKey != nil { a.tracer.UpdatedKey(a.keyPhase, false) } } diff --git a/vendor/github.com/quic-go/quic-go/internal/protocol/params.go b/vendor/github.com/quic-go/quic-go/internal/protocol/params.go index fe3a75625..487cbc06b 100644 --- a/vendor/github.com/quic-go/quic-go/internal/protocol/params.go +++ b/vendor/github.com/quic-go/quic-go/internal/protocol/params.go @@ -65,9 +65,6 @@ const MaxAcceptQueueSize = 32 // TokenValidity is the duration that a (non-retry) token is considered valid const TokenValidity = 24 * time.Hour -// RetryTokenValidity is the duration that a retry token is considered valid -const RetryTokenValidity = 10 * time.Second - // MaxOutstandingSentPackets is maximum number of packets saved for retransmission. // When reached, it imposes a soft limit on sending new packets: // Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. @@ -108,9 +105,6 @@ const DefaultIdleTimeout = 30 * time.Second // DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion. const DefaultHandshakeIdleTimeout = 5 * time.Second -// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds. -const DefaultHandshakeTimeout = 10 * time.Second - // MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive. // It should be shorter than the time that NATs clear their mapping. const MaxKeepAliveInterval = 20 * time.Second @@ -135,13 +129,6 @@ const MaxPostHandshakeCryptoFrameSize = 1000 // but must ensure that a maximum size ACK frame fits into one packet. const MaxAckFrameSize ByteCount = 1000 -// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame (RFC 9221). -// The size is chosen such that a DATAGRAM frame fits into a QUIC packet. -const MaxDatagramFrameSize ByteCount = 1200 - -// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221) -const DatagramRcvQueueLen = 128 - // MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame. // It also serves as a limit for the packet history. // If at any point we keep track of more ranges, old ranges are discarded. diff --git a/vendor/github.com/quic-go/quic-go/internal/protocol/perspective.go b/vendor/github.com/quic-go/quic-go/internal/protocol/perspective.go index 43358fecb..5a29d3ce2 100644 --- a/vendor/github.com/quic-go/quic-go/internal/protocol/perspective.go +++ b/vendor/github.com/quic-go/quic-go/internal/protocol/perspective.go @@ -17,9 +17,9 @@ func (p Perspective) Opposite() Perspective { func (p Perspective) String() string { switch p { case PerspectiveServer: - return "Server" + return "server" case PerspectiveClient: - return "Client" + return "client" default: return "invalid perspective" } diff --git a/vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go b/vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go index b8104882b..d056cb9de 100644 --- a/vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go +++ b/vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go @@ -37,14 +37,48 @@ func (t PacketType) String() string { type ECN uint8 const ( - ECNNon ECN = iota // 00 - ECT1 // 01 - ECT0 // 10 - ECNCE // 11 + ECNUnsupported ECN = iota + ECNNon // 00 + ECT1 // 01 + ECT0 // 10 + ECNCE // 11 ) +func ParseECNHeaderBits(bits byte) ECN { + switch bits { + case 0: + return ECNNon + case 0b00000010: + return ECT0 + case 0b00000001: + return ECT1 + case 0b00000011: + return ECNCE + default: + panic("invalid ECN bits") + } +} + +func (e ECN) ToHeaderBits() byte { + //nolint:exhaustive // There are only 4 values. + switch e { + case ECNNon: + return 0 + case ECT0: + return 0b00000010 + case ECT1: + return 0b00000001 + case ECNCE: + return 0b00000011 + default: + panic("ECN unsupported") + } +} + func (e ECN) String() string { switch e { + case ECNUnsupported: + return "ECN unsupported" case ECNNon: return "Not-ECT" case ECT1: diff --git a/vendor/github.com/quic-go/quic-go/internal/protocol/version.go b/vendor/github.com/quic-go/quic-go/internal/protocol/version.go index 5c2decbdc..025ade9b4 100644 --- a/vendor/github.com/quic-go/quic-go/internal/protocol/version.go +++ b/vendor/github.com/quic-go/quic-go/internal/protocol/version.go @@ -1,14 +1,17 @@ package protocol import ( - "crypto/rand" "encoding/binary" "fmt" "math" + "sync" + "time" + + "golang.org/x/exp/rand" ) -// VersionNumber is a version number as int -type VersionNumber uint32 +// Version is a version number as int +type Version uint32 // gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions const ( @@ -18,22 +21,22 @@ const ( // The version numbers, making grepping easier const ( - VersionUnknown VersionNumber = math.MaxUint32 - versionDraft29 VersionNumber = 0xff00001d // draft-29 used to be a widely deployed version - Version1 VersionNumber = 0x1 - Version2 VersionNumber = 0x6b3343cf + VersionUnknown Version = math.MaxUint32 + versionDraft29 Version = 0xff00001d // draft-29 used to be a widely deployed version + Version1 Version = 0x1 + Version2 Version = 0x6b3343cf ) // SupportedVersions lists the versions that the server supports // must be in sorted descending order -var SupportedVersions = []VersionNumber{Version1, Version2} +var SupportedVersions = []Version{Version1, Version2} // IsValidVersion says if the version is known to quic-go -func IsValidVersion(v VersionNumber) bool { +func IsValidVersion(v Version) bool { return v == Version1 || IsSupportedVersion(SupportedVersions, v) } -func (vn VersionNumber) String() string { +func (vn Version) String() string { //nolint:exhaustive switch vn { case VersionUnknown: @@ -52,16 +55,16 @@ func (vn VersionNumber) String() string { } } -func (vn VersionNumber) isGQUIC() bool { +func (vn Version) isGQUIC() bool { return vn > gquicVersion0 && vn <= maxGquicVersion } -func (vn VersionNumber) toGQUICVersion() int { +func (vn Version) toGQUICVersion() int { return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10) } // IsSupportedVersion returns true if the server supports this version -func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { +func IsSupportedVersion(supported []Version, v Version) bool { for _, t := range supported { if t == v { return true @@ -74,7 +77,7 @@ func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { // ours is a slice of versions that we support, sorted by our preference (descending) // theirs is a slice of versions offered by the peer. The order does not matter. // The bool returned indicates if a matching version was found. -func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) { +func ChooseSupportedVersion(ours, theirs []Version) (Version, bool) { for _, ourVer := range ours { for _, theirVer := range theirs { if ourVer == theirVer { @@ -85,19 +88,25 @@ func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) return 0, false } -// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a) -func generateReservedVersion() VersionNumber { - b := make([]byte, 4) - _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything - return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa) +var ( + versionNegotiationMx sync.Mutex + versionNegotiationRand = rand.New(rand.NewSource(uint64(time.Now().UnixNano()))) +) + +// generateReservedVersion generates a reserved version (v & 0x0f0f0f0f == 0x0a0a0a0a) +func generateReservedVersion() Version { + var b [4]byte + _, _ = versionNegotiationRand.Read(b[:]) // ignore the error here. Failure to read random data doesn't break anything + return Version((binary.BigEndian.Uint32(b[:]) | 0x0a0a0a0a) & 0xfafafafa) } -// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position -func GetGreasedVersions(supported []VersionNumber) []VersionNumber { - b := make([]byte, 1) - _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything - randPos := int(b[0]) % (len(supported) + 1) - greased := make([]VersionNumber, len(supported)+1) +// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position. +// It doesn't modify the supported slice. +func GetGreasedVersions(supported []Version) []Version { + versionNegotiationMx.Lock() + defer versionNegotiationMx.Unlock() + randPos := rand.Intn(len(supported) + 1) + greased := make([]Version, len(supported)+1) copy(greased, supported[:randPos]) greased[randPos] = generateReservedVersion() copy(greased[randPos+1:], supported[randPos:]) diff --git a/vendor/github.com/quic-go/quic-go/internal/qerr/error_codes.go b/vendor/github.com/quic-go/quic-go/internal/qerr/error_codes.go index a037acd22..00361308e 100644 --- a/vendor/github.com/quic-go/quic-go/internal/qerr/error_codes.go +++ b/vendor/github.com/quic-go/quic-go/internal/qerr/error_codes.go @@ -1,9 +1,8 @@ package qerr import ( + "crypto/tls" "fmt" - - "github.com/quic-go/quic-go/internal/qtls" ) // TransportErrorCode is a QUIC transport error. @@ -40,7 +39,7 @@ func (e TransportErrorCode) Message() string { if !e.IsCryptoError() { return "" } - return qtls.AlertError(e - 0x100).Error() + return tls.AlertError(e - 0x100).Error() } func (e TransportErrorCode) String() string { diff --git a/vendor/github.com/quic-go/quic-go/internal/qerr/errors.go b/vendor/github.com/quic-go/quic-go/internal/qerr/errors.go index 2d8511f77..8f5936dfa 100644 --- a/vendor/github.com/quic-go/quic-go/internal/qerr/errors.go +++ b/vendor/github.com/quic-go/quic-go/internal/qerr/errors.go @@ -101,8 +101,8 @@ func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.Err // A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version. type VersionNegotiationError struct { - Ours []protocol.VersionNumber - Theirs []protocol.VersionNumber + Ours []protocol.Version + Theirs []protocol.Version } func (e *VersionNegotiationError) Error() string { diff --git a/vendor/github.com/quic-go/quic-go/internal/qtls/cipher_suite_go121.go b/vendor/github.com/quic-go/quic-go/internal/qtls/cipher_suite.go similarity index 85% rename from vendor/github.com/quic-go/quic-go/internal/qtls/cipher_suite_go121.go rename to vendor/github.com/quic-go/quic-go/internal/qtls/cipher_suite.go index aa8c768fd..32a921cd5 100644 --- a/vendor/github.com/quic-go/quic-go/internal/qtls/cipher_suite_go121.go +++ b/vendor/github.com/quic-go/quic-go/internal/qtls/cipher_suite.go @@ -1,25 +1,11 @@ -//go:build go1.21 - package qtls import ( - "crypto" - "crypto/cipher" "crypto/tls" "fmt" "unsafe" ) -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - //go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13 var cipherSuitesTLS13 []unsafe.Pointer diff --git a/vendor/github.com/quic-go/quic-go/internal/qtls/client_session_cache.go b/vendor/github.com/quic-go/quic-go/internal/qtls/client_session_cache.go index 519895ee6..4acac9e2e 100644 --- a/vendor/github.com/quic-go/quic-go/internal/qtls/client_session_cache.go +++ b/vendor/github.com/quic-go/quic-go/internal/qtls/client_session_cache.go @@ -1,20 +1,23 @@ -//go:build go1.21 - package qtls import ( "crypto/tls" + "sync" ) type clientSessionCache struct { - getData func() []byte - setData func([]byte) + mx sync.Mutex + getData func(earlyData bool) []byte + setData func(data []byte, earlyData bool) (allowEarlyData bool) wrapped tls.ClientSessionCache } var _ tls.ClientSessionCache = &clientSessionCache{} -func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) { +func (c *clientSessionCache) Put(key string, cs *tls.ClientSessionState) { + c.mx.Lock() + defer c.mx.Unlock() + if cs == nil { c.wrapped.Put(key, nil) return @@ -24,7 +27,7 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) { c.wrapped.Put(key, cs) return } - state.Extra = append(state.Extra, addExtraPrefix(c.getData())) + state.Extra = append(state.Extra, addExtraPrefix(c.getData(state.EarlyData))) newCS, err := tls.NewResumptionState(ticket, state) if err != nil { // It's not clear why this would error. Just save the original state. @@ -34,7 +37,10 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) { c.wrapped.Put(key, newCS) } -func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { +func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { + c.mx.Lock() + defer c.mx.Unlock() + cs, ok := c.wrapped.Get(key) if !ok || cs == nil { return cs, ok @@ -48,7 +54,10 @@ func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { } // restore QUIC transport parameters and RTT stored in state.Extra if extra := findExtraData(state.Extra); extra != nil { - c.setData(extra) + earlyData := c.setData(extra, state.EarlyData) + if state.EarlyData { + state.EarlyData = earlyData + } } session, err := tls.NewResumptionState(ticket, state) if err != nil { diff --git a/vendor/github.com/quic-go/quic-go/internal/qtls/go120.go b/vendor/github.com/quic-go/quic-go/internal/qtls/go120.go deleted file mode 100644 index 3b50441c4..000000000 --- a/vendor/github.com/quic-go/quic-go/internal/qtls/go120.go +++ /dev/null @@ -1,145 +0,0 @@ -//go:build go1.20 && !go1.21 - -package qtls - -import ( - "crypto/tls" - "fmt" - "unsafe" - - "github.com/quic-go/quic-go/internal/protocol" - - "github.com/quic-go/qtls-go1-20" -) - -type ( - QUICConn = qtls.QUICConn - QUICConfig = qtls.QUICConfig - QUICEvent = qtls.QUICEvent - QUICEventKind = qtls.QUICEventKind - QUICEncryptionLevel = qtls.QUICEncryptionLevel - AlertError = qtls.AlertError -) - -const ( - QUICEncryptionLevelInitial = qtls.QUICEncryptionLevelInitial - QUICEncryptionLevelEarly = qtls.QUICEncryptionLevelEarly - QUICEncryptionLevelHandshake = qtls.QUICEncryptionLevelHandshake - QUICEncryptionLevelApplication = qtls.QUICEncryptionLevelApplication -) - -const ( - QUICNoEvent = qtls.QUICNoEvent - QUICSetReadSecret = qtls.QUICSetReadSecret - QUICSetWriteSecret = qtls.QUICSetWriteSecret - QUICWriteData = qtls.QUICWriteData - QUICTransportParameters = qtls.QUICTransportParameters - QUICTransportParametersRequired = qtls.QUICTransportParametersRequired - QUICRejectedEarlyData = qtls.QUICRejectedEarlyData - QUICHandshakeDone = qtls.QUICHandshakeDone -) - -func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, accept0RTT func([]byte) bool) { - qtls.InitSessionTicketKeys(conf.TLSConfig) - conf.TLSConfig = conf.TLSConfig.Clone() - conf.TLSConfig.MinVersion = tls.VersionTLS13 - conf.ExtraConfig = &qtls.ExtraConfig{ - Enable0RTT: enable0RTT, - Accept0RTT: accept0RTT, - GetAppDataForSessionTicket: getDataForSessionTicket, - } -} - -func SetupConfigForClient(conf *QUICConfig, getDataForSessionState func() []byte, setDataFromSessionState func([]byte)) { - conf.ExtraConfig = &qtls.ExtraConfig{ - GetAppDataForSessionState: getDataForSessionState, - SetAppDataFromSessionState: setDataFromSessionState, - } -} - -func QUICServer(config *QUICConfig) *QUICConn { - return qtls.QUICServer(config) -} - -func QUICClient(config *QUICConfig) *QUICConn { - return qtls.QUICClient(config) -} - -func ToTLSEncryptionLevel(e protocol.EncryptionLevel) qtls.QUICEncryptionLevel { - switch e { - case protocol.EncryptionInitial: - return qtls.QUICEncryptionLevelInitial - case protocol.EncryptionHandshake: - return qtls.QUICEncryptionLevelHandshake - case protocol.Encryption1RTT: - return qtls.QUICEncryptionLevelApplication - case protocol.Encryption0RTT: - return qtls.QUICEncryptionLevelEarly - default: - panic(fmt.Sprintf("unexpected encryption level: %s", e)) - } -} - -func FromTLSEncryptionLevel(e qtls.QUICEncryptionLevel) protocol.EncryptionLevel { - switch e { - case qtls.QUICEncryptionLevelInitial: - return protocol.EncryptionInitial - case qtls.QUICEncryptionLevelHandshake: - return protocol.EncryptionHandshake - case qtls.QUICEncryptionLevelApplication: - return protocol.Encryption1RTT - case qtls.QUICEncryptionLevelEarly: - return protocol.Encryption0RTT - default: - panic(fmt.Sprintf("unexpect encryption level: %s", e)) - } -} - -//go:linkname cipherSuitesTLS13 github.com/quic-go/qtls-go1-20.cipherSuitesTLS13 -var cipherSuitesTLS13 []unsafe.Pointer - -//go:linkname defaultCipherSuitesTLS13 github.com/quic-go/qtls-go1-20.defaultCipherSuitesTLS13 -var defaultCipherSuitesTLS13 []uint16 - -//go:linkname defaultCipherSuitesTLS13NoAES github.com/quic-go/qtls-go1-20.defaultCipherSuitesTLS13NoAES -var defaultCipherSuitesTLS13NoAES []uint16 - -var cipherSuitesModified bool - -// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls -// such that it only contains the cipher suite with the chosen id. -// The reset function returned resets them back to the original value. -func SetCipherSuite(id uint16) (reset func()) { - if cipherSuitesModified { - panic("cipher suites modified multiple times without resetting") - } - cipherSuitesModified = true - - origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...) - origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...) - origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...) - // The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls. - switch id { - case tls.TLS_AES_128_GCM_SHA256: - cipherSuitesTLS13 = cipherSuitesTLS13[:1] - case tls.TLS_CHACHA20_POLY1305_SHA256: - cipherSuitesTLS13 = cipherSuitesTLS13[1:2] - case tls.TLS_AES_256_GCM_SHA384: - cipherSuitesTLS13 = cipherSuitesTLS13[2:] - default: - panic(fmt.Sprintf("unexpected cipher suite: %d", id)) - } - defaultCipherSuitesTLS13 = []uint16{id} - defaultCipherSuitesTLS13NoAES = []uint16{id} - - return func() { - cipherSuitesTLS13 = origCipherSuitesTLS13 - defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13 - defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES - cipherSuitesModified = false - } -} - -func SendSessionTicket(c *QUICConn, allow0RTT bool) error { - return c.SendSessionTicket(allow0RTT) -} diff --git a/vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go b/vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go deleted file mode 100644 index 0fca80a38..000000000 --- a/vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go +++ /dev/null @@ -1,5 +0,0 @@ -//go:build !go1.20 - -package qtls - -var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions." diff --git a/vendor/github.com/quic-go/quic-go/internal/qtls/go121.go b/vendor/github.com/quic-go/quic-go/internal/qtls/qtls.go similarity index 58% rename from vendor/github.com/quic-go/quic-go/internal/qtls/go121.go rename to vendor/github.com/quic-go/quic-go/internal/qtls/qtls.go index 4aebc4a14..ebcd9d4de 100644 --- a/vendor/github.com/quic-go/quic-go/internal/qtls/go121.go +++ b/vendor/github.com/quic-go/quic-go/internal/qtls/qtls.go @@ -1,5 +1,3 @@ -//go:build go1.21 - package qtls import ( @@ -10,38 +8,7 @@ import ( "github.com/quic-go/quic-go/internal/protocol" ) -type ( - QUICConn = tls.QUICConn - QUICConfig = tls.QUICConfig - QUICEvent = tls.QUICEvent - QUICEventKind = tls.QUICEventKind - QUICEncryptionLevel = tls.QUICEncryptionLevel - QUICSessionTicketOptions = tls.QUICSessionTicketOptions - AlertError = tls.AlertError -) - -const ( - QUICEncryptionLevelInitial = tls.QUICEncryptionLevelInitial - QUICEncryptionLevelEarly = tls.QUICEncryptionLevelEarly - QUICEncryptionLevelHandshake = tls.QUICEncryptionLevelHandshake - QUICEncryptionLevelApplication = tls.QUICEncryptionLevelApplication -) - -const ( - QUICNoEvent = tls.QUICNoEvent - QUICSetReadSecret = tls.QUICSetReadSecret - QUICSetWriteSecret = tls.QUICSetWriteSecret - QUICWriteData = tls.QUICWriteData - QUICTransportParameters = tls.QUICTransportParameters - QUICTransportParametersRequired = tls.QUICTransportParametersRequired - QUICRejectedEarlyData = tls.QUICRejectedEarlyData - QUICHandshakeDone = tls.QUICHandshakeDone -) - -func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) } -func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) } - -func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) { +func SetupConfigForServer(qconf *tls.QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) { conf := qconf.TLSConfig // Workaround for https://github.com/golang/go/issues/60506. @@ -55,11 +22,9 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce // add callbacks to save transport parameters into the session ticket origWrapSession := conf.WrapSession conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) { - // Add QUIC transport parameters if this is a 0-RTT packet. - // TODO(#3853): also save the RTT for non-0-RTT tickets - if state.EarlyData { - state.Extra = append(state.Extra, addExtraPrefix(getData())) - } + // Add QUIC session ticket + state.Extra = append(state.Extra, addExtraPrefix(getData())) + if origWrapSession != nil { return origWrapSession(cs, state) } @@ -83,19 +48,23 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce if err != nil || state == nil { return nil, err } - if state.EarlyData { - extra := findExtraData(state.Extra) - if unwrapCount == 1 && extra != nil { // first session ticket - state.EarlyData = accept0RTT(extra) - } else { // subsequent session ticket, can't be used for 0-RTT - state.EarlyData = false - } + + extra := findExtraData(state.Extra) + if extra != nil { + state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1) + } else { + state.EarlyData = false } + return state, nil } } -func SetupConfigForClient(qconf *QUICConfig, getData func() []byte, setData func([]byte)) { +func SetupConfigForClient( + qconf *tls.QUICConfig, + getData func(earlyData bool) []byte, + setData func(data []byte, earlyData bool) (allowEarlyData bool), +) { conf := qconf.TLSConfig if conf.ClientSessionCache != nil { origCache := conf.ClientSessionCache @@ -153,9 +122,3 @@ func findExtraData(extras [][]byte) []byte { } return nil } - -func SendSessionTicket(c *QUICConn, allow0RTT bool) error { - return c.SendSessionTicket(tls.QUICSessionTicketOptions{ - EarlyData: allow0RTT, - }) -} diff --git a/vendor/github.com/quic-go/quic-go/internal/utils/minmax.go b/vendor/github.com/quic-go/quic-go/internal/utils/minmax.go index d191f7515..03a9c9a87 100644 --- a/vendor/github.com/quic-go/quic-go/internal/utils/minmax.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/minmax.go @@ -3,27 +3,11 @@ package utils import ( "math" "time" - - "golang.org/x/exp/constraints" ) // InfDuration is a duration of infinite length const InfDuration = time.Duration(math.MaxInt64) -func Max[T constraints.Ordered](a, b T) T { - if a < b { - return b - } - return a -} - -func Min[T constraints.Ordered](a, b T) T { - if a < b { - return a - } - return b -} - // MinNonZeroDuration return the minimum duration that's not zero. func MinNonZeroDuration(a, b time.Duration) time.Duration { if a == 0 { @@ -32,15 +16,7 @@ func MinNonZeroDuration(a, b time.Duration) time.Duration { if b == 0 { return a } - return Min(a, b) -} - -// AbsDuration returns the absolute value of a time duration -func AbsDuration(d time.Duration) time.Duration { - if d >= 0 { - return d - } - return -d + return min(a, b) } // MinTime returns the earlier time @@ -51,18 +27,6 @@ func MinTime(a, b time.Time) time.Time { return a } -// MinNonZeroTime returns the earlist time that is not time.Time{} -// If both a and b are time.Time{}, it returns time.Time{} -func MinNonZeroTime(a, b time.Time) time.Time { - if a.IsZero() { - return b - } - if b.IsZero() { - return a - } - return MinTime(a, b) -} - // MaxTime returns the later time func MaxTime(a, b time.Time) time.Time { if a.After(b) { diff --git a/vendor/github.com/quic-go/quic-go/internal/utils/ringbuffer/ringbuffer.go b/vendor/github.com/quic-go/quic-go/internal/utils/ringbuffer/ringbuffer.go index 81a5ad44b..f9b2c797b 100644 --- a/vendor/github.com/quic-go/quic-go/internal/utils/ringbuffer/ringbuffer.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/ringbuffer/ringbuffer.go @@ -8,7 +8,7 @@ type RingBuffer[T any] struct { full bool } -// Init preallocs a buffer with a certain size. +// Init preallocates a buffer with a certain size. func (r *RingBuffer[T]) Init(size int) { r.ring = make([]T, size) } @@ -62,6 +62,16 @@ func (r *RingBuffer[T]) PopFront() T { return t } +// PeekFront returns the next element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) PeekFront() T { + if r.Empty() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: peek from an empty queue") + } + return r.ring[r.headPos] +} + // Grow the maximum size of the queue. // This method assume the queue is full. func (r *RingBuffer[T]) grow() { diff --git a/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go index 2cd9a1919..463b95424 100644 --- a/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go @@ -55,7 +55,7 @@ func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { if r.SmoothedRTT() == 0 { return 2 * defaultInitialRTT } - pto := r.SmoothedRTT() + Max(4*r.MeanDeviation(), protocol.TimerGranularity) + pto := r.SmoothedRTT() + max(4*r.MeanDeviation(), protocol.TimerGranularity) if includeMaxAckDelay { pto += r.MaxAckDelay() } @@ -90,7 +90,7 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { r.smoothedRTT = sample r.meanDeviation = sample / 2 } else { - r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond + r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32((r.smoothedRTT-sample).Abs()/time.Microsecond)) * time.Microsecond r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond } } @@ -126,6 +126,6 @@ func (r *RTTStats) OnConnectionMigration() { // is larger. The mean deviation is increased to the most recent deviation if // it's larger. func (r *RTTStats) ExpireSmoothedMetrics() { - r.meanDeviation = Max(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT)) - r.smoothedRTT = Max(r.smoothedRTT, r.latestRTT) + r.meanDeviation = max(r.meanDeviation, (r.smoothedRTT - r.latestRTT).Abs()) + r.smoothedRTT = max(r.smoothedRTT, r.latestRTT) } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go index 9b23cc25f..a0f3feb06 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go @@ -22,7 +22,7 @@ type AckFrame struct { } // parseAckFrame reads an ACK frame -func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) error { +func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.Version) error { ecn := typ == ackECNFrameType la, err := quicvarint.Read(r) @@ -37,7 +37,7 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen delayTime := time.Duration(delay*1< 0 || f.ECT1 > 0 || f.ECNCE > 0 if hasECN { b = append(b, ackECNFrameType) @@ -143,7 +143,7 @@ func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { } // Length of a written frame -func (f *AckFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *AckFrame) Length(_ protocol.Version) protocol.ByteCount { largestAcked := f.AckRanges[0].Largest numRanges := f.numEncodableAckRanges() diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/connection_close_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/connection_close_frame.go index f56c2c0d8..df3624474 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/connection_close_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/connection_close_frame.go @@ -16,7 +16,7 @@ type ConnectionCloseFrame struct { ReasonPhrase string } -func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) { +func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, error) { f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType} ec, err := quicvarint.Read(r) if err != nil { @@ -53,7 +53,7 @@ func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNu } // Length of a written frame -func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount { +func (f *ConnectionCloseFrame) Length(protocol.Version) protocol.ByteCount { length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) if !f.IsApplicationError { length += quicvarint.Len(f.FrameType) // for the frame type @@ -61,7 +61,7 @@ func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount return length } -func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { if f.IsApplicationError { b = append(b, applicationCloseFrameType) } else { diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/crypto_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/crypto_frame.go index 0f005c5ba..d42146391 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/crypto_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/crypto_frame.go @@ -14,7 +14,7 @@ type CryptoFrame struct { Data []byte } -func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) { +func parseCryptoFrame(r *bytes.Reader, _ protocol.Version) (*CryptoFrame, error) { frame := &CryptoFrame{} offset, err := quicvarint.Read(r) if err != nil { @@ -38,7 +38,7 @@ func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, return frame, nil } -func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, cryptoFrameType) b = quicvarint.Append(b, uint64(f.Offset)) b = quicvarint.Append(b, uint64(len(f.Data))) @@ -47,7 +47,7 @@ func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) } // Length of a written frame -func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *CryptoFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data)) } @@ -71,7 +71,7 @@ func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount // The frame might not be split if: // * the size is large enough to fit the whole frame // * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. -func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*CryptoFrame, bool /* was splitting required */) { +func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*CryptoFrame, bool /* was splitting required */) { if f.Length(version) <= maxSize { return nil, false } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/data_blocked_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/data_blocked_frame.go index 0d4d1f565..8fe2acb54 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/data_blocked_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/data_blocked_frame.go @@ -12,7 +12,7 @@ type DataBlockedFrame struct { MaximumData protocol.ByteCount } -func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) { +func parseDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*DataBlockedFrame, error) { offset, err := quicvarint.Read(r) if err != nil { return nil, err @@ -20,12 +20,12 @@ func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBloc return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil } -func (f *DataBlockedFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) { +func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) { b = append(b, dataBlockedFrameType) return quicvarint.Append(b, uint64(f.MaximumData)), nil } // Length of a written frame -func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *DataBlockedFrame) Length(version protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.MaximumData)) } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/datagram_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/datagram_frame.go index e6c451961..8e406f1ad 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/datagram_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/datagram_frame.go @@ -8,13 +8,19 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) +// MaxDatagramSize is the maximum size of a DATAGRAM frame (RFC 9221). +// By setting it to a large value, we allow all datagrams that fit into a QUIC packet. +// The value is chosen such that it can still be encoded as a 2 byte varint. +// This is a var and not a const so it can be set in tests. +var MaxDatagramSize protocol.ByteCount = 16383 + // A DatagramFrame is a DATAGRAM frame type DatagramFrame struct { DataLenPresent bool Data []byte } -func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*DatagramFrame, error) { +func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*DatagramFrame, error) { f := &DatagramFrame{} f.DataLenPresent = typ&0x1 > 0 @@ -39,7 +45,7 @@ func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) ( return f, nil } -func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { typ := uint8(0x30) if f.DataLenPresent { typ ^= 0b1 @@ -53,7 +59,7 @@ func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro } // MaxDataLen returns the maximum data length -func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { +func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount { headerLen := protocol.ByteCount(1) if f.DataLenPresent { // pretend that the data size will be 1 bytes @@ -71,7 +77,7 @@ func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol. } // Length of a written frame -func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *DatagramFrame) Length(_ protocol.Version) protocol.ByteCount { length := 1 + protocol.ByteCount(len(f.Data)) if f.DataLenPresent { length += quicvarint.Len(uint64(len(f.Data))) diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/extended_header.go b/vendor/github.com/quic-go/quic-go/internal/wire/extended_header.go index d10820d6d..e04d91b78 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/extended_header.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/extended_header.go @@ -32,7 +32,7 @@ type ExtendedHeader struct { parsedLen protocol.ByteCount } -func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) { +func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) { startLen := b.Len() // read the (now unencrypted) first byte var err error @@ -51,7 +51,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool return reservedBitsValid, err } -func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { +func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) { if err := h.readPacketNumber(b); err != nil { return false, err } @@ -95,7 +95,7 @@ func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { } // Append appends the Header. -func (h *ExtendedHeader) Append(b []byte, v protocol.VersionNumber) ([]byte, error) { +func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) { if h.DestConnectionID.Len() > protocol.MaxConnIDLen { return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) } @@ -162,7 +162,7 @@ func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { } // GetLength determines the length of the Header. -func (h *ExtendedHeader) GetLength(_ protocol.VersionNumber) protocol.ByteCount { +func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount { length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ if h.Type == protocol.PacketTypeInitial { length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go b/vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go index ff35dd101..cf7d4cecd 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go @@ -36,7 +36,8 @@ const ( handshakeDoneFrameType = 0x1e ) -type frameParser struct { +// The FrameParser parses QUIC frames, one by one. +type FrameParser struct { r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them ackDelayExponent uint8 @@ -47,11 +48,9 @@ type frameParser struct { ackFrame *AckFrame } -var _ FrameParser = &frameParser{} - // NewFrameParser creates a new frame parser. -func NewFrameParser(supportsDatagrams bool) *frameParser { - return &frameParser{ +func NewFrameParser(supportsDatagrams bool) *FrameParser { + return &FrameParser{ r: *bytes.NewReader(nil), supportsDatagrams: supportsDatagrams, ackFrame: &AckFrame{}, @@ -60,7 +59,7 @@ func NewFrameParser(supportsDatagrams bool) *frameParser { // ParseNext parses the next frame. // It skips PADDING frames. -func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (int, Frame, error) { +func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) { startLen := len(data) p.r.Reset(data) frame, err := p.parseNext(&p.r, encLevel, v) @@ -69,7 +68,7 @@ func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, return n, frame, err } -func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) { +func (p *FrameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) { for r.Len() != 0 { typ, err := quicvarint.Read(r) if err != nil { @@ -95,7 +94,7 @@ func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLev return nil, nil } -func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) { +func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) { var frame Frame var err error if typ&0xf8 == 0x8 { @@ -163,7 +162,7 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol. return frame, nil } -func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { +func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { switch encLevel { case protocol.EncryptionInitial, protocol.EncryptionHandshake: switch f.(type) { @@ -186,6 +185,8 @@ func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionL } } -func (p *frameParser) SetAckDelayExponent(exp uint8) { +// SetAckDelayExponent sets the acknowledgment delay exponent (sent in the transport parameters). +// This value is used to scale the ACK Delay field in the ACK frame. +func (p *FrameParser) SetAckDelayExponent(exp uint8) { p.ackDelayExponent = exp } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/handshake_done_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/handshake_done_frame.go index 29521bc9e..85dd64745 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/handshake_done_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/handshake_done_frame.go @@ -7,11 +7,11 @@ import ( // A HandshakeDoneFrame is a HANDSHAKE_DONE frame type HandshakeDoneFrame struct{} -func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { return append(b, handshakeDoneFrameType), nil } // Length of a written frame -func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *HandshakeDoneFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/header.go b/vendor/github.com/quic-go/quic-go/internal/wire/header.go index 0c60f4dd9..299116849 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/header.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/header.go @@ -85,11 +85,11 @@ func IsLongHeaderPacket(firstByte byte) bool { // ParseVersion parses the QUIC version. // It should only be called for Long Header packets (Short Header packets don't contain a version number). -func ParseVersion(data []byte) (protocol.VersionNumber, error) { +func ParseVersion(data []byte) (protocol.Version, error) { if len(data) < 5 { return 0, io.EOF } - return protocol.VersionNumber(binary.BigEndian.Uint32(data[1:5])), nil + return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil } // IsVersionNegotiationPacket says if this is a version negotiation packet @@ -109,7 +109,7 @@ func Is0RTTPacket(b []byte) bool { if !IsLongHeaderPacket(b[0]) { return false } - version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5])) + version := protocol.Version(binary.BigEndian.Uint32(b[1:5])) //nolint:exhaustive // We only need to test QUIC versions that we support. switch version { case protocol.Version1: @@ -128,7 +128,7 @@ type Header struct { typeByte byte Type protocol.PacketType - Version protocol.VersionNumber + Version protocol.Version SrcConnectionID protocol.ConnectionID DestConnectionID protocol.ConnectionID @@ -184,7 +184,7 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { if err != nil { return err } - h.Version = protocol.VersionNumber(v) + h.Version = protocol.Version(v) if h.Version != 0 && h.typeByte&0x40 == 0 { return errors.New("not a QUIC packet") } @@ -278,7 +278,7 @@ func (h *Header) ParsedLen() protocol.ByteCount { // ParseExtended parses the version dependent part of the header. // The Reader has to be set such that it points to the first byte of the header. -func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) { +func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) { extHdr := h.toExtendedHeader() reservedBitsValid, err := extHdr.parse(b, ver) if err != nil { diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/interface.go b/vendor/github.com/quic-go/quic-go/internal/wire/interface.go index 7e0f9a03e..bc17883b5 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/interface.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/interface.go @@ -6,12 +6,6 @@ import ( // A Frame in QUIC type Frame interface { - Append(b []byte, version protocol.VersionNumber) ([]byte, error) - Length(version protocol.VersionNumber) protocol.ByteCount -} - -// A FrameParser parses QUIC frames, one by one. -type FrameParser interface { - ParseNext([]byte, protocol.EncryptionLevel, protocol.VersionNumber) (int, Frame, error) - SetAckDelayExponent(uint8) + Append(b []byte, version protocol.Version) ([]byte, error) + Length(version protocol.Version) protocol.ByteCount } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/log.go b/vendor/github.com/quic-go/quic-go/internal/wire/log.go index ec7d45d86..c8b28d924 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/log.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/log.go @@ -63,7 +63,9 @@ func LogFrame(logger utils.Logger, frame Frame, sent bool) { logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit) } case *NewConnectionIDFrame: - logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken) + logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, RetirePriorTo: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.RetirePriorTo, f.ConnectionID, f.StatelessResetToken) + case *RetireConnectionIDFrame: + logger.Debugf("\t%s &wire.RetireConnectionIDFrame{SequenceNumber: %d}", dir, f.SequenceNumber) case *NewTokenFrame: logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token) default: diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/max_data_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/max_data_frame.go index e61b0f9f3..3dfd76116 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/max_data_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/max_data_frame.go @@ -13,7 +13,7 @@ type MaxDataFrame struct { } // parseMaxDataFrame parses a MAX_DATA frame -func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame, error) { +func parseMaxDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxDataFrame, error) { frame := &MaxDataFrame{} byteOffset, err := quicvarint.Read(r) if err != nil { @@ -23,13 +23,13 @@ func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame return frame, nil } -func (f *MaxDataFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, maxDataFrameType) b = quicvarint.Append(b, uint64(f.MaximumData)) return b, nil } // Length of a written frame -func (f *MaxDataFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *MaxDataFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.MaximumData)) } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/max_stream_data_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/max_stream_data_frame.go index fe3d1e3f7..cb5eab1b0 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/max_stream_data_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/max_stream_data_frame.go @@ -13,7 +13,7 @@ type MaxStreamDataFrame struct { MaximumStreamData protocol.ByteCount } -func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamDataFrame, error) { +func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxStreamDataFrame, error) { sid, err := quicvarint.Read(r) if err != nil { return nil, err @@ -29,7 +29,7 @@ func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStr }, nil } -func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) { +func (f *MaxStreamDataFrame) Append(b []byte, version protocol.Version) ([]byte, error) { b = append(b, maxStreamDataFrameType) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.MaximumStreamData)) @@ -37,6 +37,6 @@ func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([ } // Length of a written frame -func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *MaxStreamDataFrame) Length(version protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/max_streams_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/max_streams_frame.go index bd278c02a..d90293383 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/max_streams_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/max_streams_frame.go @@ -14,7 +14,7 @@ type MaxStreamsFrame struct { MaxStreamNum protocol.StreamNum } -func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*MaxStreamsFrame, error) { +func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*MaxStreamsFrame, error) { f := &MaxStreamsFrame{} switch typ { case bidiMaxStreamsFrameType: @@ -33,7 +33,7 @@ func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) return f, nil } -func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: b = append(b, bidiMaxStreamsFrameType) @@ -45,6 +45,6 @@ func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, er } // Length of a written frame -func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount { +func (f *MaxStreamsFrame) Length(protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.MaxStreamNum)) } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/new_connection_id_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/new_connection_id_frame.go index 83102d5d1..afae010ad 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/new_connection_id_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/new_connection_id_frame.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "errors" "fmt" "io" @@ -17,7 +18,7 @@ type NewConnectionIDFrame struct { StatelessResetToken protocol.StatelessResetToken } -func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewConnectionIDFrame, error) { +func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*NewConnectionIDFrame, error) { seq, err := quicvarint.Read(r) if err != nil { return nil, err @@ -34,6 +35,9 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC if err != nil { return nil, err } + if connIDLen == 0 { + return nil, errors.New("invalid zero-length connection ID") + } connID, err := protocol.ReadConnectionID(r, int(connIDLen)) if err != nil { return nil, err @@ -53,7 +57,7 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC return frame, nil } -func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, newConnectionIDFrameType) b = quicvarint.Append(b, f.SequenceNumber) b = quicvarint.Append(b, f.RetirePriorTo) @@ -68,6 +72,6 @@ func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byt } // Length of a written frame -func (f *NewConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { +func (f *NewConnectionIDFrame) Length(protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16 } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/new_token_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/new_token_frame.go index c3fa178c1..6a2eac945 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/new_token_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/new_token_frame.go @@ -14,7 +14,7 @@ type NewTokenFrame struct { Token []byte } -func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFrame, error) { +func parseNewTokenFrame(r *bytes.Reader, _ protocol.Version) (*NewTokenFrame, error) { tokenLen, err := quicvarint.Read(r) if err != nil { return nil, err @@ -32,7 +32,7 @@ func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFra return &NewTokenFrame{Token: token}, nil } -func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, newTokenFrameType) b = quicvarint.Append(b, uint64(len(f.Token))) b = append(b, f.Token...) @@ -40,6 +40,6 @@ func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro } // Length of a written frame -func (f *NewTokenFrame) Length(protocol.VersionNumber) protocol.ByteCount { +func (f *NewTokenFrame) Length(protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token)) } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/path_challenge_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/path_challenge_frame.go index ad024330a..772041ac6 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/path_challenge_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/path_challenge_frame.go @@ -12,7 +12,7 @@ type PathChallengeFrame struct { Data [8]byte } -func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathChallengeFrame, error) { +func parsePathChallengeFrame(r *bytes.Reader, _ protocol.Version) (*PathChallengeFrame, error) { frame := &PathChallengeFrame{} if _, err := io.ReadFull(r, frame.Data[:]); err != nil { if err == io.ErrUnexpectedEOF { @@ -23,13 +23,13 @@ func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathCh return frame, nil } -func (f *PathChallengeFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, pathChallengeFrameType) b = append(b, f.Data[:]...) return b, nil } // Length of a written frame -func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *PathChallengeFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + 8 } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/path_response_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/path_response_frame.go index 76e651049..86bbe619f 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/path_response_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/path_response_frame.go @@ -12,7 +12,7 @@ type PathResponseFrame struct { Data [8]byte } -func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathResponseFrame, error) { +func parsePathResponseFrame(r *bytes.Reader, _ protocol.Version) (*PathResponseFrame, error) { frame := &PathResponseFrame{} if _, err := io.ReadFull(r, frame.Data[:]); err != nil { if err == io.ErrUnexpectedEOF { @@ -23,13 +23,13 @@ func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathRes return frame, nil } -func (f *PathResponseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, pathResponseFrameType) b = append(b, f.Data[:]...) return b, nil } // Length of a written frame -func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *PathResponseFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + 8 } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/ping_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/ping_frame.go index dd24edc0c..71f8d16c3 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/ping_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/ping_frame.go @@ -7,11 +7,11 @@ import ( // A PingFrame is a PING frame type PingFrame struct{} -func (f *PingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *PingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { return append(b, pingFrameType), nil } // Length of a written frame -func (f *PingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *PingFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/reset_stream_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/reset_stream_frame.go index cd94c9408..e60f1db12 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/reset_stream_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/reset_stream_frame.go @@ -15,7 +15,7 @@ type ResetStreamFrame struct { FinalSize protocol.ByteCount } -func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStreamFrame, error) { +func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFrame, error) { var streamID protocol.StreamID var byteOffset protocol.ByteCount sid, err := quicvarint.Read(r) @@ -40,7 +40,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStr }, nil } -func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, resetStreamFrameType) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.ErrorCode)) @@ -49,6 +49,6 @@ func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, e } // Length of a written frame -func (f *ResetStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *ResetStreamFrame) Length(version protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize)) } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/retire_connection_id_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/retire_connection_id_frame.go index 8e9a41d89..981536224 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/retire_connection_id_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/retire_connection_id_frame.go @@ -12,7 +12,7 @@ type RetireConnectionIDFrame struct { SequenceNumber uint64 } -func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*RetireConnectionIDFrame, error) { +func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*RetireConnectionIDFrame, error) { seq, err := quicvarint.Read(r) if err != nil { return nil, err @@ -20,13 +20,13 @@ func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*R return &RetireConnectionIDFrame{SequenceNumber: seq}, nil } -func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, retireConnectionIDFrameType) b = quicvarint.Append(b, f.SequenceNumber) return b, nil } // Length of a written frame -func (f *RetireConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { +func (f *RetireConnectionIDFrame) Length(protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(f.SequenceNumber) } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/stop_sending_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/stop_sending_frame.go index d7b8b2402..d314a5698 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/stop_sending_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/stop_sending_frame.go @@ -15,7 +15,7 @@ type StopSendingFrame struct { } // parseStopSendingFrame parses a STOP_SENDING frame -func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) { +func parseStopSendingFrame(r *bytes.Reader, _ protocol.Version) (*StopSendingFrame, error) { streamID, err := quicvarint.Read(r) if err != nil { return nil, err @@ -32,11 +32,11 @@ func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSend } // Length of a written frame -func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *StopSendingFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) } -func (f *StopSendingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *StopSendingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, stopSendingFrameType) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.ErrorCode)) diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/stream_data_blocked_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/stream_data_blocked_frame.go index d42e59a2d..f79740f98 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/stream_data_blocked_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/stream_data_blocked_frame.go @@ -13,7 +13,7 @@ type StreamDataBlockedFrame struct { MaximumStreamData protocol.ByteCount } -func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) { +func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*StreamDataBlockedFrame, error) { sid, err := quicvarint.Read(r) if err != nil { return nil, err @@ -29,7 +29,7 @@ func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*St }, nil } -func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, 0x15) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.MaximumStreamData)) @@ -37,6 +37,6 @@ func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]b } // Length of a written frame -func (f *StreamDataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *StreamDataBlockedFrame) Length(version protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/stream_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/stream_frame.go index d22e1c052..0f6c00da2 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/stream_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/stream_frame.go @@ -20,7 +20,7 @@ type StreamFrame struct { fromPool bool } -func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamFrame, error) { +func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamFrame, error) { hasOffset := typ&0b100 > 0 fin := typ&0b1 > 0 hasDataLen := typ&0b10 > 0 @@ -79,7 +79,7 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*S } // Write writes a STREAM frame -func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { if len(f.Data) == 0 && !f.Fin { return nil, errors.New("StreamFrame: attempting to write empty frame without FIN") } @@ -108,7 +108,7 @@ func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) } // Length returns the total length of the STREAM frame -func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *StreamFrame) Length(version protocol.Version) protocol.ByteCount { length := 1 + quicvarint.Len(uint64(f.StreamID)) if f.Offset != 0 { length += quicvarint.Len(uint64(f.Offset)) @@ -126,7 +126,7 @@ func (f *StreamFrame) DataLen() protocol.ByteCount { // MaxDataLen returns the maximum data length // If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data). -func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { +func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount { headerLen := 1 + quicvarint.Len(uint64(f.StreamID)) if f.Offset != 0 { headerLen += quicvarint.Len(uint64(f.Offset)) @@ -151,7 +151,7 @@ func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Ve // The frame might not be split if: // * the size is large enough to fit the whole frame // * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. -func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, bool /* was splitting required */) { +func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*StreamFrame, bool /* was splitting required */) { if maxSize >= f.Length(version) { return nil, false } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/streams_blocked_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/streams_blocked_frame.go index 4a5951c62..b24619ab0 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/streams_blocked_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/streams_blocked_frame.go @@ -14,7 +14,7 @@ type StreamsBlockedFrame struct { StreamLimit protocol.StreamNum } -func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) { +func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, error) { f := &StreamsBlockedFrame{} switch typ { case bidiStreamBlockedFrameType: @@ -33,7 +33,7 @@ func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNum return f, nil } -func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: b = append(b, bidiStreamBlockedFrameType) @@ -45,6 +45,6 @@ func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte } // Length of a written frame -func (f *StreamsBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *StreamsBlockedFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamLimit)) } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/transport_parameters.go b/vendor/github.com/quic-go/quic-go/internal/wire/transport_parameters.go index 7226521b0..c03be3cd7 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/transport_parameters.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/transport_parameters.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "sort" "time" @@ -51,10 +51,7 @@ const ( // PreferredAddress is the value encoding in the preferred_address transport parameter type PreferredAddress struct { - IPv4 net.IP - IPv4Port uint16 - IPv6 net.IP - IPv6Port uint16 + IPv4, IPv6 netip.AddrPort ConnectionID protocol.ConnectionID StatelessResetToken protocol.StatelessResetToken } @@ -218,26 +215,24 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error { remainingLen := r.Len() pa := &PreferredAddress{} - ipv4 := make([]byte, 4) - if _, err := io.ReadFull(r, ipv4); err != nil { + var ipv4 [4]byte + if _, err := io.ReadFull(r, ipv4[:]); err != nil { return err } - pa.IPv4 = net.IP(ipv4) port, err := utils.BigEndian.ReadUint16(r) if err != nil { return err } - pa.IPv4Port = port - ipv6 := make([]byte, 16) - if _, err := io.ReadFull(r, ipv6); err != nil { + pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port) + var ipv6 [16]byte + if _, err := io.ReadFull(r, ipv6[:]); err != nil { return err } - pa.IPv6 = net.IP(ipv6) port, err = utils.BigEndian.ReadUint16(r) if err != nil { return err } - pa.IPv6Port = port + pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port) connIDLen, err := r.ReadByte() if err != nil { return err @@ -294,7 +289,7 @@ func (p *TransportParameters) readNumericTransportParameter( return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount) } case maxIdleTimeoutParameterID: - p.MaxIdleTimeout = utils.Max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) + p.MaxIdleTimeout = max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) case maxUDPPayloadSizeParameterID: if val < 1200 { return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val) @@ -384,13 +379,12 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { if p.PreferredAddress != nil { b = quicvarint.Append(b, uint64(preferredAddressParameterID)) b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) - ipv4 := p.PreferredAddress.IPv4 - b = append(b, ipv4[len(ipv4)-4:]...) - b = append(b, []byte{0, 0}...) - binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv4Port) - b = append(b, p.PreferredAddress.IPv6...) - b = append(b, []byte{0, 0}...) - binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv6Port) + ip4 := p.PreferredAddress.IPv4.Addr().As4() + b = append(b, ip4[:]...) + b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv4.Port()) + ip6 := p.PreferredAddress.IPv6.Addr().As16() + b = append(b, ip6[:]...) + b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv6.Port()) b = append(b, uint8(p.PreferredAddress.ConnectionID.Len())) b = append(b, p.PreferredAddress.ConnectionID.Bytes()...) b = append(b, p.PreferredAddress.StatelessResetToken[:]...) diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/version_negotiation.go b/vendor/github.com/quic-go/quic-go/internal/wire/version_negotiation.go index afde70fa4..a551aa8c9 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/version_negotiation.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/version_negotiation.go @@ -1,17 +1,15 @@ package wire import ( - "bytes" "crypto/rand" "encoding/binary" "errors" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" ) // ParseVersionNegotiationPacket parses a Version Negotiation packet. -func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber, _ error) { +func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.Version, _ error) { n, dest, src, err := ParseArbitraryLenConnectionIDs(b) if err != nil { return nil, nil, nil, err @@ -25,32 +23,31 @@ func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenCon //nolint:stylecheck return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") } - versions := make([]protocol.VersionNumber, len(b)/4) + versions := make([]protocol.Version, len(b)/4) for i := 0; len(b) > 0; i++ { - versions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(b[:4])) + versions[i] = protocol.Version(binary.BigEndian.Uint32(b[:4])) b = b[4:] } return dest, src, versions, nil } // ComposeVersionNegotiation composes a Version Negotiation -func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.VersionNumber) []byte { +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.Version) []byte { greasedVersions := protocol.GetGreasedVersions(versions) expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 - buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) - r := make([]byte, 1) - _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. + buf := make([]byte, 1+4 /* type byte and version field */, expectedLen) + _, _ = rand.Read(buf[:1]) // ignore the error here. It is not critical to have perfect random here. // Setting the "QUIC bit" (0x40) is not required by the RFC, // but it allows clients to demultiplex QUIC with a long list of other protocols. // See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details. - buf.WriteByte(r[0] | 0xc0) - utils.BigEndian.WriteUint32(buf, 0) // version 0 - buf.WriteByte(uint8(destConnID.Len())) - buf.Write(destConnID.Bytes()) - buf.WriteByte(uint8(srcConnID.Len())) - buf.Write(srcConnID.Bytes()) + buf[0] |= 0xc0 + // The next 4 bytes are left at 0 (version number). + buf = append(buf, uint8(destConnID.Len())) + buf = append(buf, destConnID.Bytes()...) + buf = append(buf, uint8(srcConnID.Len())) + buf = append(buf, srcConnID.Bytes()...) for _, v := range greasedVersions { - utils.BigEndian.WriteUint32(buf, uint32(v)) + buf = binary.BigEndian.AppendUint32(buf, uint32(v)) } - return buf.Bytes() + return buf } diff --git a/vendor/github.com/quic-go/quic-go/logging/connection_tracer.go b/vendor/github.com/quic-go/quic-go/logging/connection_tracer.go new file mode 100644 index 000000000..7f54d6cda --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/logging/connection_tracer.go @@ -0,0 +1,263 @@ +package logging + +import ( + "net" + "time" +) + +// A ConnectionTracer records events. +type ConnectionTracer struct { + StartedConnection func(local, remote net.Addr, srcConnID, destConnID ConnectionID) + NegotiatedVersion func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) + ClosedConnection func(error) + SentTransportParameters func(*TransportParameters) + ReceivedTransportParameters func(*TransportParameters) + RestoredTransportParameters func(parameters *TransportParameters) // for 0-RTT + SentLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) + SentShortHeaderPacket func(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) + ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, _ []VersionNumber) + ReceivedRetry func(*Header) + ReceivedLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, []Frame) + ReceivedShortHeaderPacket func(*ShortHeader, ByteCount, ECN, []Frame) + BufferedPacket func(PacketType, ByteCount) + DroppedPacket func(PacketType, PacketNumber, ByteCount, PacketDropReason) + UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) + AcknowledgedPacket func(EncryptionLevel, PacketNumber) + LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason) + UpdatedCongestionState func(CongestionState) + UpdatedPTOCount func(value uint32) + UpdatedKeyFromTLS func(EncryptionLevel, Perspective) + UpdatedKey func(keyPhase KeyPhase, remote bool) + DroppedEncryptionLevel func(EncryptionLevel) + DroppedKey func(keyPhase KeyPhase) + SetLossTimer func(TimerType, EncryptionLevel, time.Time) + LossTimerExpired func(TimerType, EncryptionLevel) + LossTimerCanceled func() + ECNStateUpdated func(state ECNState, trigger ECNStateTrigger) + ChoseALPN func(protocol string) + // Close is called when the connection is closed. + Close func() + Debug func(name, msg string) +} + +// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. +func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &ConnectionTracer{ + StartedConnection: func(local, remote net.Addr, srcConnID, destConnID ConnectionID) { + for _, t := range tracers { + if t.StartedConnection != nil { + t.StartedConnection(local, remote, srcConnID, destConnID) + } + } + }, + NegotiatedVersion: func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { + for _, t := range tracers { + if t.NegotiatedVersion != nil { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + } + } + }, + ClosedConnection: func(e error) { + for _, t := range tracers { + if t.ClosedConnection != nil { + t.ClosedConnection(e) + } + } + }, + SentTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.SentTransportParameters != nil { + t.SentTransportParameters(tp) + } + } + }, + ReceivedTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.ReceivedTransportParameters != nil { + t.ReceivedTransportParameters(tp) + } + } + }, + RestoredTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.RestoredTransportParameters != nil { + t.RestoredTransportParameters(tp) + } + } + }, + SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentLongHeaderPacket != nil { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentShortHeaderPacket != nil { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + ReceivedVersionNegotiationPacket: func(dest, src ArbitraryLenConnectionID, versions []VersionNumber) { + for _, t := range tracers { + if t.ReceivedVersionNegotiationPacket != nil { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + } + } + }, + ReceivedRetry: func(hdr *Header) { + for _, t := range tracers { + if t.ReceivedRetry != nil { + t.ReceivedRetry(hdr) + } + } + }, + ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedLongHeaderPacket != nil { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + } + } + }, + ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedShortHeaderPacket != nil { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + } + } + }, + BufferedPacket: func(typ PacketType, size ByteCount) { + for _, t := range tracers { + if t.BufferedPacket != nil { + t.BufferedPacket(typ, size) + } + } + }, + DroppedPacket: func(typ PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(typ, pn, size, reason) + } + } + }, + UpdatedMetrics: func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { + for _, t := range tracers { + if t.UpdatedMetrics != nil { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + } + } + }, + AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) { + for _, t := range tracers { + if t.AcknowledgedPacket != nil { + t.AcknowledgedPacket(encLevel, pn) + } + } + }, + LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { + for _, t := range tracers { + if t.LostPacket != nil { + t.LostPacket(encLevel, pn, reason) + } + } + }, + UpdatedCongestionState: func(state CongestionState) { + for _, t := range tracers { + if t.UpdatedCongestionState != nil { + t.UpdatedCongestionState(state) + } + } + }, + UpdatedPTOCount: func(value uint32) { + for _, t := range tracers { + if t.UpdatedPTOCount != nil { + t.UpdatedPTOCount(value) + } + } + }, + UpdatedKeyFromTLS: func(encLevel EncryptionLevel, perspective Perspective) { + for _, t := range tracers { + if t.UpdatedKeyFromTLS != nil { + t.UpdatedKeyFromTLS(encLevel, perspective) + } + } + }, + UpdatedKey: func(generation KeyPhase, remote bool) { + for _, t := range tracers { + if t.UpdatedKey != nil { + t.UpdatedKey(generation, remote) + } + } + }, + DroppedEncryptionLevel: func(encLevel EncryptionLevel) { + for _, t := range tracers { + if t.DroppedEncryptionLevel != nil { + t.DroppedEncryptionLevel(encLevel) + } + } + }, + DroppedKey: func(generation KeyPhase) { + for _, t := range tracers { + if t.DroppedKey != nil { + t.DroppedKey(generation) + } + } + }, + SetLossTimer: func(typ TimerType, encLevel EncryptionLevel, exp time.Time) { + for _, t := range tracers { + if t.SetLossTimer != nil { + t.SetLossTimer(typ, encLevel, exp) + } + } + }, + LossTimerExpired: func(typ TimerType, encLevel EncryptionLevel) { + for _, t := range tracers { + if t.LossTimerExpired != nil { + t.LossTimerExpired(typ, encLevel) + } + } + }, + LossTimerCanceled: func() { + for _, t := range tracers { + if t.LossTimerCanceled != nil { + t.LossTimerCanceled() + } + } + }, + ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) { + for _, t := range tracers { + if t.ECNStateUpdated != nil { + t.ECNStateUpdated(state, trigger) + } + } + }, + ChoseALPN: func(protocol string) { + for _, t := range tracers { + if t.ChoseALPN != nil { + t.ChoseALPN(protocol) + } + } + }, + Close: func() { + for _, t := range tracers { + if t.Close != nil { + t.Close() + } + } + }, + Debug: func(name, msg string) { + for _, t := range tracers { + if t.Debug != nil { + t.Debug(name, msg) + } + } + }, + } +} diff --git a/vendor/github.com/quic-go/quic-go/logging/interface.go b/vendor/github.com/quic-go/quic-go/logging/interface.go index 2ce8582ec..a618a1893 100644 --- a/vendor/github.com/quic-go/quic-go/logging/interface.go +++ b/vendor/github.com/quic-go/quic-go/logging/interface.go @@ -3,9 +3,6 @@ package logging import ( - "net" - "time" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" @@ -15,6 +12,8 @@ import ( type ( // A ByteCount is used to count bytes. ByteCount = protocol.ByteCount + // ECN is the ECN value + ECN = protocol.ECN // A ConnectionID is a QUIC Connection ID. ConnectionID = protocol.ConnectionID // An ArbitraryLenConnectionID is a QUIC Connection ID that can be up to 255 bytes long. @@ -38,7 +37,7 @@ type ( // The StreamType is the type of the stream (unidirectional or bidirectional). StreamType = protocol.StreamType // The VersionNumber is the QUIC version. - VersionNumber = protocol.VersionNumber + VersionNumber = protocol.Version // The Header is the QUIC packet header, before removing header protection. Header = wire.Header @@ -58,6 +57,19 @@ type ( RTTStats = utils.RTTStats ) +const ( + // ECNUnsupported means that no ECN value was set / received + ECNUnsupported = protocol.ECNUnsupported + // ECTNot is Not-ECT + ECTNot = protocol.ECNNon + // ECT0 is ECT(0) + ECT0 = protocol.ECT0 + // ECT1 is ECT(1) + ECT1 = protocol.ECT1 + // ECNCE is CE + ECNCE = protocol.ECNCE +) + const ( // KeyPhaseZero is key phase bit 0 KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero @@ -97,43 +109,3 @@ type ShortHeader struct { PacketNumberLen protocol.PacketNumberLen KeyPhase KeyPhaseBit } - -// A Tracer traces events. -type Tracer interface { - SentPacket(net.Addr, *Header, ByteCount, []Frame) - SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) - DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) -} - -// A ConnectionTracer records events. -type ConnectionTracer interface { - StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) - NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) - ClosedConnection(error) - SentTransportParameters(*TransportParameters) - ReceivedTransportParameters(*TransportParameters) - RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT - SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) - SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame) - ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) - ReceivedRetry(*Header) - ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) - ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame) - BufferedPacket(PacketType, ByteCount) - DroppedPacket(PacketType, ByteCount, PacketDropReason) - UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) - AcknowledgedPacket(EncryptionLevel, PacketNumber) - LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) - UpdatedCongestionState(CongestionState) - UpdatedPTOCount(value uint32) - UpdatedKeyFromTLS(EncryptionLevel, Perspective) - UpdatedKey(generation KeyPhase, remote bool) - DroppedEncryptionLevel(EncryptionLevel) - DroppedKey(generation KeyPhase) - SetLossTimer(TimerType, EncryptionLevel, time.Time) - LossTimerExpired(TimerType, EncryptionLevel) - LossTimerCanceled() - // Close is called when the connection is closed. - Close() - Debug(name, msg string) -} diff --git a/vendor/github.com/quic-go/quic-go/logging/mockgen.go b/vendor/github.com/quic-go/quic-go/logging/mockgen.go deleted file mode 100644 index d50916799..000000000 --- a/vendor/github.com/quic-go/quic-go/logging/mockgen.go +++ /dev/null @@ -1,4 +0,0 @@ -package logging - -//go:generate sh -c "go run github.com/golang/mock/mockgen -package logging -self_package github.com/quic-go/quic-go/logging -destination mock_connection_tracer_test.go github.com/quic-go/quic-go/logging ConnectionTracer" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package logging -self_package github.com/quic-go/quic-go/logging -destination mock_tracer_test.go github.com/quic-go/quic-go/logging Tracer" diff --git a/vendor/github.com/quic-go/quic-go/logging/multiplex.go b/vendor/github.com/quic-go/quic-go/logging/multiplex.go deleted file mode 100644 index 672a5cdbd..000000000 --- a/vendor/github.com/quic-go/quic-go/logging/multiplex.go +++ /dev/null @@ -1,226 +0,0 @@ -package logging - -import ( - "net" - "time" -) - -type tracerMultiplexer struct { - tracers []Tracer -} - -var _ Tracer = &tracerMultiplexer{} - -// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. -func NewMultiplexedTracer(tracers ...Tracer) Tracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &tracerMultiplexer{tracers} -} - -func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.SentPacket(remote, hdr, size, frames) - } -} - -func (m *tracerMultiplexer) SentVersionNegotiationPacket(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) { - for _, t := range m.tracers { - t.SentVersionNegotiationPacket(remote, dest, src, versions) - } -} - -func (m *tracerMultiplexer) DroppedPacket(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(remote, typ, size, reason) - } -} - -type connTracerMultiplexer struct { - tracers []ConnectionTracer -} - -var _ ConnectionTracer = &connTracerMultiplexer{} - -// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. -func NewMultiplexedConnectionTracer(tracers ...ConnectionTracer) ConnectionTracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &connTracerMultiplexer{tracers: tracers} -} - -func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { - for _, t := range m.tracers { - t.StartedConnection(local, remote, srcConnID, destConnID) - } -} - -func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { - for _, t := range m.tracers { - t.NegotiatedVersion(chosen, clientVersions, serverVersions) - } -} - -func (m *connTracerMultiplexer) ClosedConnection(e error) { - for _, t := range m.tracers { - t.ClosedConnection(e) - } -} - -func (m *connTracerMultiplexer) SentTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.SentTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) ReceivedTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.ReceivedTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.RestoredTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) { - for _, t := range m.tracers { - t.SentLongHeaderPacket(hdr, size, ack, frames) - } -} - -func (m *connTracerMultiplexer) SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame) { - for _, t := range m.tracers { - t.SentShortHeaderPacket(hdr, size, ack, frames) - } -} - -func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, versions []VersionNumber) { - for _, t := range m.tracers { - t.ReceivedVersionNegotiationPacket(dest, src, versions) - } -} - -func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) { - for _, t := range m.tracers { - t.ReceivedRetry(hdr) - } -} - -func (m *connTracerMultiplexer) ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.ReceivedLongHeaderPacket(hdr, size, frames) - } -} - -func (m *connTracerMultiplexer) ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.ReceivedShortHeaderPacket(hdr, size, frames) - } -} - -func (m *connTracerMultiplexer) BufferedPacket(typ PacketType, size ByteCount) { - for _, t := range m.tracers { - t.BufferedPacket(typ, size) - } -} - -func (m *connTracerMultiplexer) DroppedPacket(typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(typ, size, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedCongestionState(state CongestionState) { - for _, t := range m.tracers { - t.UpdatedCongestionState(state) - } -} - -func (m *connTracerMultiplexer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFLight ByteCount, packetsInFlight int) { - for _, t := range m.tracers { - t.UpdatedMetrics(rttStats, cwnd, bytesInFLight, packetsInFlight) - } -} - -func (m *connTracerMultiplexer) AcknowledgedPacket(encLevel EncryptionLevel, pn PacketNumber) { - for _, t := range m.tracers { - t.AcknowledgedPacket(encLevel, pn) - } -} - -func (m *connTracerMultiplexer) LostPacket(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { - for _, t := range m.tracers { - t.LostPacket(encLevel, pn, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedPTOCount(value uint32) { - for _, t := range m.tracers { - t.UpdatedPTOCount(value) - } -} - -func (m *connTracerMultiplexer) UpdatedKeyFromTLS(encLevel EncryptionLevel, perspective Perspective) { - for _, t := range m.tracers { - t.UpdatedKeyFromTLS(encLevel, perspective) - } -} - -func (m *connTracerMultiplexer) UpdatedKey(generation KeyPhase, remote bool) { - for _, t := range m.tracers { - t.UpdatedKey(generation, remote) - } -} - -func (m *connTracerMultiplexer) DroppedEncryptionLevel(encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.DroppedEncryptionLevel(encLevel) - } -} - -func (m *connTracerMultiplexer) DroppedKey(generation KeyPhase) { - for _, t := range m.tracers { - t.DroppedKey(generation) - } -} - -func (m *connTracerMultiplexer) SetLossTimer(typ TimerType, encLevel EncryptionLevel, exp time.Time) { - for _, t := range m.tracers { - t.SetLossTimer(typ, encLevel, exp) - } -} - -func (m *connTracerMultiplexer) LossTimerExpired(typ TimerType, encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.LossTimerExpired(typ, encLevel) - } -} - -func (m *connTracerMultiplexer) LossTimerCanceled() { - for _, t := range m.tracers { - t.LossTimerCanceled() - } -} - -func (m *connTracerMultiplexer) Debug(name, msg string) { - for _, t := range m.tracers { - t.Debug(name, msg) - } -} - -func (m *connTracerMultiplexer) Close() { - for _, t := range m.tracers { - t.Close() - } -} diff --git a/vendor/github.com/quic-go/quic-go/logging/null_tracer.go b/vendor/github.com/quic-go/quic-go/logging/null_tracer.go deleted file mode 100644 index de9703857..000000000 --- a/vendor/github.com/quic-go/quic-go/logging/null_tracer.go +++ /dev/null @@ -1,58 +0,0 @@ -package logging - -import ( - "net" - "time" -) - -// The NullTracer is a Tracer that does nothing. -// It is useful for embedding. -type NullTracer struct{} - -var _ Tracer = &NullTracer{} - -func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {} -func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) { -} -func (n NullTracer) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) {} - -// The NullConnectionTracer is a ConnectionTracer that does nothing. -// It is useful for embedding. -type NullConnectionTracer struct{} - -var _ ConnectionTracer = &NullConnectionTracer{} - -func (n NullConnectionTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { -} - -func (n NullConnectionTracer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { -} -func (n NullConnectionTracer) ClosedConnection(err error) {} -func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) SentLongHeaderPacket(*ExtendedHeader, ByteCount, *AckFrame, []Frame) {} -func (n NullConnectionTracer) SentShortHeaderPacket(*ShortHeader, ByteCount, *AckFrame, []Frame) {} -func (n NullConnectionTracer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) { -} -func (n NullConnectionTracer) ReceivedRetry(*Header) {} -func (n NullConnectionTracer) ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, []Frame) {} -func (n NullConnectionTracer) ReceivedShortHeaderPacket(*ShortHeader, ByteCount, []Frame) {} -func (n NullConnectionTracer) BufferedPacket(PacketType, ByteCount) {} -func (n NullConnectionTracer) DroppedPacket(PacketType, ByteCount, PacketDropReason) {} - -func (n NullConnectionTracer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { -} -func (n NullConnectionTracer) AcknowledgedPacket(EncryptionLevel, PacketNumber) {} -func (n NullConnectionTracer) LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) {} -func (n NullConnectionTracer) UpdatedCongestionState(CongestionState) {} -func (n NullConnectionTracer) UpdatedPTOCount(uint32) {} -func (n NullConnectionTracer) UpdatedKeyFromTLS(EncryptionLevel, Perspective) {} -func (n NullConnectionTracer) UpdatedKey(keyPhase KeyPhase, remote bool) {} -func (n NullConnectionTracer) DroppedEncryptionLevel(EncryptionLevel) {} -func (n NullConnectionTracer) DroppedKey(KeyPhase) {} -func (n NullConnectionTracer) SetLossTimer(TimerType, EncryptionLevel, time.Time) {} -func (n NullConnectionTracer) LossTimerExpired(timerType TimerType, level EncryptionLevel) {} -func (n NullConnectionTracer) LossTimerCanceled() {} -func (n NullConnectionTracer) Close() {} -func (n NullConnectionTracer) Debug(name, msg string) {} diff --git a/vendor/github.com/quic-go/quic-go/logging/tracer.go b/vendor/github.com/quic-go/quic-go/logging/tracer.go new file mode 100644 index 000000000..edd85dbaa --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/logging/tracer.go @@ -0,0 +1,59 @@ +package logging + +import "net" + +// A Tracer traces events. +type Tracer struct { + SentPacket func(net.Addr, *Header, ByteCount, []Frame) + SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) + DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason) + Debug func(name, msg string) + Close func() +} + +// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. +func NewMultiplexedTracer(tracers ...*Tracer) *Tracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &Tracer{ + SentPacket: func(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { + for _, t := range tracers { + if t.SentPacket != nil { + t.SentPacket(remote, hdr, size, frames) + } + } + }, + SentVersionNegotiationPacket: func(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) { + for _, t := range tracers { + if t.SentVersionNegotiationPacket != nil { + t.SentVersionNegotiationPacket(remote, dest, src, versions) + } + } + }, + DroppedPacket: func(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(remote, typ, size, reason) + } + } + }, + Debug: func(name, msg string) { + for _, t := range tracers { + if t.Debug != nil { + t.Debug(name, msg) + } + } + }, + Close: func() { + for _, t := range tracers { + if t.Close != nil { + t.Close() + } + } + }, + } +} diff --git a/vendor/github.com/quic-go/quic-go/logging/types.go b/vendor/github.com/quic-go/quic-go/logging/types.go index ad8006923..0d79b0a90 100644 --- a/vendor/github.com/quic-go/quic-go/logging/types.go +++ b/vendor/github.com/quic-go/quic-go/logging/types.go @@ -92,3 +92,37 @@ const ( // CongestionStateApplicationLimited means that the congestion controller is application limited CongestionStateApplicationLimited ) + +// ECNState is the state of the ECN state machine (see Appendix A.4 of RFC 9000) +type ECNState uint8 + +const ( + // ECNStateTesting is the testing state + ECNStateTesting ECNState = 1 + iota + // ECNStateUnknown is the unknown state + ECNStateUnknown + // ECNStateFailed is the failed state + ECNStateFailed + // ECNStateCapable is the capable state + ECNStateCapable +) + +// ECNStateTrigger is a trigger for an ECN state transition. +type ECNStateTrigger uint8 + +const ( + ECNTriggerNoTrigger ECNStateTrigger = iota + // ECNFailedNoECNCounts is emitted when an ACK acknowledges ECN-marked packets, + // but doesn't contain any ECN counts + ECNFailedNoECNCounts + // ECNFailedDecreasedECNCounts is emitted when an ACK frame decreases ECN counts + ECNFailedDecreasedECNCounts + // ECNFailedLostAllTestingPackets is emitted when all ECN testing packets are declared lost + ECNFailedLostAllTestingPackets + // ECNFailedMoreECNCountsThanSent is emitted when an ACK contains more ECN counts than ECN-marked packets were sent + ECNFailedMoreECNCountsThanSent + // ECNFailedTooFewECNCounts is emitted when an ACK contains fewer ECN counts than it acknowledges packets + ECNFailedTooFewECNCounts + // ECNFailedManglingDetected is emitted when the path marks all ECN-marked packets as CE + ECNFailedManglingDetected +) diff --git a/vendor/github.com/quic-go/quic-go/mockgen.go b/vendor/github.com/quic-go/quic-go/mockgen.go index 221c1367f..81cc4a5ef 100644 --- a/vendor/github.com/quic-go/quic-go/mockgen.go +++ b/vendor/github.com/quic-go/quic-go/mockgen.go @@ -2,76 +2,75 @@ package quic -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_conn_test.go github.com/quic-go/quic-go SendConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_conn_test.go github.com/quic-go/quic-go SendConn" type SendConn = sendConn -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_raw_conn_test.go github.com/quic-go/quic-go RawConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_raw_conn_test.go github.com/quic-go/quic-go RawConn" type RawConn = rawConn -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sender_test.go github.com/quic-go/quic-go Sender" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sender_test.go github.com/quic-go/quic-go Sender" type Sender = sender -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_internal_test.go github.com/quic-go/quic-go StreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_internal_test.go github.com/quic-go/quic-go StreamI" type StreamI = streamI -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_stream_test.go github.com/quic-go/quic-go CryptoStream" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_stream_test.go github.com/quic-go/quic-go CryptoStream" type CryptoStream = cryptoStream -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_receive_stream_internal_test.go github.com/quic-go/quic-go ReceiveStreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_receive_stream_internal_test.go github.com/quic-go/quic-go ReceiveStreamI" type ReceiveStreamI = receiveStreamI -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI" type SendStreamI = sendStreamI -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter" type StreamGetter = streamGetter -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender" type StreamSender = streamSender -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler" type CryptoDataHandler = cryptoDataHandler -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource" type FrameSource = frameSource -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource" type AckFrameSource = ackFrameSource -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager" type StreamManager = streamManager -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager" type SealingManager = sealingManager -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_unpacker_test.go github.com/quic-go/quic-go Unpacker" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_unpacker_test.go github.com/quic-go/quic-go Unpacker" type Unpacker = unpacker -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packer_test.go github.com/quic-go/quic-go Packer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packer_test.go github.com/quic-go/quic-go Packer" type Packer = packer -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_mtu_discoverer_test.go github.com/quic-go/quic-go MTUDiscoverer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_mtu_discoverer_test.go github.com/quic-go/quic-go MTUDiscoverer" type MTUDiscoverer = mtuDiscoverer -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_conn_runner_test.go github.com/quic-go/quic-go ConnRunner" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_conn_runner_test.go github.com/quic-go/quic-go ConnRunner" type ConnRunner = connRunner -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_quic_conn_test.go github.com/quic-go/quic-go QUICConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_quic_conn_test.go github.com/quic-go/quic-go QUICConn" type QUICConn = quicConn -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler" type PacketHandler = packetHandler -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_unknown_packet_handler_test.go github.com/quic-go/quic-go UnknownPacketHandler" -type UnknownPacketHandler = unknownPacketHandler +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" type PacketHandlerManager = packetHandlerManager // Need to use source mode for the batchConn, since reflect mode follows type aliases. // See https://github.com/golang/mock/issues/244 for details. // -//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_token_store_test.go github.com/quic-go/quic-go TokenStore" -//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_token_store_test.go github.com/quic-go/quic-go TokenStore" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn" diff --git a/vendor/github.com/quic-go/quic-go/oss-fuzz.sh b/vendor/github.com/quic-go/quic-go/oss-fuzz.sh index 1efe6baa6..22a577fe1 100644 --- a/vendor/github.com/quic-go/quic-go/oss-fuzz.sh +++ b/vendor/github.com/quic-go/quic-go/oss-fuzz.sh @@ -3,12 +3,12 @@ # Install Go manually, since oss-fuzz ships with an outdated Go version. # See https://github.com/google/oss-fuzz/pull/10643. export CXX="${CXX} -lresolv" # required by Go 1.20 -wget https://go.dev/dl/go1.20.5.linux-amd64.tar.gz \ +wget https://go.dev/dl/go1.22.0.linux-amd64.tar.gz \ && mkdir temp-go \ && rm -rf /root/.go/* \ - && tar -C temp-go/ -xzf go1.20.5.linux-amd64.tar.gz \ + && tar -C temp-go/ -xzf go1.22.0.linux-amd64.tar.gz \ && mv temp-go/go/* /root/.go/ \ - && rm -rf temp-go go1.20.5.linux-amd64.tar.gz + && rm -rf temp-go go1.22.0.linux-amd64.tar.gz ( # fuzz qpack diff --git a/vendor/github.com/quic-go/quic-go/packet_handler_map.go b/vendor/github.com/quic-go/quic-go/packet_handler_map.go index e0f0567d7..7840202cc 100644 --- a/vendor/github.com/quic-go/quic-go/packet_handler_map.go +++ b/vendor/github.com/quic-go/quic-go/packet_handler_map.go @@ -4,7 +4,6 @@ import ( "crypto/hmac" "crypto/rand" "crypto/sha256" - "errors" "hash" "io" "net" @@ -21,14 +20,17 @@ type connCapabilities struct { DF bool // GSO (Generic Segmentation Offload) supported GSO bool + // ECN (Explicit Congestion Notifications) supported + ECN bool } // rawConn is a connection that allow reading of a receivedPackeh. type rawConn interface { ReadPacket() (receivedPacket, error) // WritePacket writes a packet on the wire. - // If GSO is enabled, it's the caller's responsibility to set the correct control message. - WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) + // gsoSize is the size of a single packet, or 0 to disable GSO. + // It is invalid to set gsoSize if capabilities.GSO is not set. + WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) LocalAddr() net.Addr SetReadDeadline(time.Time) error io.Closer @@ -42,13 +44,6 @@ type closePacket struct { info packetInfo } -type unknownPacketHandler interface { - handlePacket(receivedPacket) - setCloseError(error) -} - -var errListenerAlreadySet = errors.New("listener already set") - type packetHandlerMap struct { mutex sync.Mutex handlers map[protocol.ConnectionID]packetHandler @@ -134,7 +129,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) return true } -func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool { +func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool { h.mutex.Lock() defer h.mutex.Unlock() @@ -142,12 +137,8 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) return false } - conn, ok := fn() - if !ok { - return false - } - h.handlers[clientDestConnID] = conn - h.handlers[newConnID] = conn + h.handlers[clientDestConnID] = handler + h.handlers[newConnID] = handler h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) return true } @@ -173,18 +164,17 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { // Depending on which side closed the connection, we need to: // * remote close: absorb delayed packets // * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost -func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) { +func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte) { var handler packetHandler if connClosePacket != nil { handler = newClosedLocalConn( func(addr net.Addr, info packetInfo) { h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info}) }, - pers, h.logger, ) } else { - handler = newClosedRemoteConn(pers) + handler = newClosedRemoteConn() } h.mutex.Lock() @@ -196,7 +186,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p time.AfterFunc(h.deleteRetiredConnsAfter, func() { h.mutex.Lock() - handler.shutdown() for _, id := range ids { delete(h.handlers, id) } @@ -225,23 +214,6 @@ func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) ( return handler, ok } -func (h *packetHandlerMap) CloseServer() { - h.mutex.Lock() - var wg sync.WaitGroup - for _, handler := range h.handlers { - if handler.getPerspective() == protocol.PerspectiveServer { - wg.Add(1) - go func(handler packetHandler) { - // blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped - handler.shutdown() - wg.Done() - }(handler) - } - } - h.mutex.Unlock() - wg.Wait() -} - func (h *packetHandlerMap) Close(e error) { h.mutex.Lock() diff --git a/vendor/github.com/quic-go/quic-go/packet_packer.go b/vendor/github.com/quic-go/quic-go/packet_packer.go index 577f3b043..e707734fb 100644 --- a/vendor/github.com/quic-go/quic-go/packet_packer.go +++ b/vendor/github.com/quic-go/quic-go/packet_packer.go @@ -1,9 +1,13 @@ package quic import ( + crand "crypto/rand" + "encoding/binary" "errors" "fmt" + "golang.org/x/exp/rand" + "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" @@ -14,13 +18,13 @@ import ( var errNothingToPack = errors.New("nothing to pack") type packer interface { - PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) - PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) - AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) - MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) - PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) - PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) - PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) + PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) + PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) + AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) + MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) + PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) + PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) + PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) SetToken([]byte) } @@ -67,6 +71,11 @@ type coalescedPacket struct { shortHdrPacket *shortHeaderPacket } +// IsOnlyShortHeaderPacket says if this packet only contains a short header packet (and no long header packets). +func (p *coalescedPacket) IsOnlyShortHeaderPacket() bool { + return len(p.longHdrPackets) == 0 && p.shortHdrPacket != nil +} + func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel { //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data). switch p.header.Type { @@ -97,8 +106,8 @@ type sealingManager interface { type frameSource interface { HasData() bool - AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) - AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) + AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) } type ackFrameSource interface { @@ -122,6 +131,7 @@ type packetPacker struct { acks ackFrameSource datagramQueue *datagramQueue retransmissionQueue *retransmissionQueue + rand rand.Rand numNonAckElicitingAcks int } @@ -140,6 +150,9 @@ func newPacketPacker( datagramQueue *datagramQueue, perspective protocol.Perspective, ) *packetPacker { + var b [8]byte + _, _ = crand.Read(b[:]) + return &packetPacker{ cryptoSetup: cryptoSetup, getDestConnID: getDestConnID, @@ -151,12 +164,13 @@ func newPacketPacker( perspective: perspective, framer: framer, acks: acks, + rand: *rand.New(rand.NewSource(binary.BigEndian.Uint64(b[:]))), pnManager: packetNumberManager, } } // PackConnectionClose packs a packet that closes the connection with a transport error. -func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { var reason string // don't send details of crypto errors if !e.ErrorCode.IsCryptoError() { @@ -166,7 +180,7 @@ func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize } // PackApplicationClose packs a packet that closes the connection with an application error. -func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v) } @@ -176,7 +190,7 @@ func (p *packetPacker) packConnectionClose( frameType uint64, reason string, maxPacketSize protocol.ByteCount, - v protocol.VersionNumber, + v protocol.Version, ) (*coalescedPacket, error) { var sealers [4]sealer var hdrs [3]*wire.ExtendedHeader @@ -279,7 +293,7 @@ func (p *packetPacker) packConnectionClose( // longHeaderPacketLength calculates the length of a serialized long header packet. // It takes into account that packets that have a tiny payload need to be padded, // such that len(payload) + packet number len >= 4 + AEAD overhead -func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.VersionNumber) protocol.ByteCount { +func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.Version) protocol.ByteCount { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(hdr.PacketNumberLen) if pl.length < 4-pnLen { @@ -314,7 +328,7 @@ func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize, // PackCoalescedPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. -func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { var ( initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload @@ -428,7 +442,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. // PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { +func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { buf := getPacketBuffer() packet, err := p.appendPacket(buf, true, maxPacketSize, v) return packet, buf, err @@ -436,11 +450,11 @@ func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v pro // AppendPacket packs a packet in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { +func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) { return p.appendPacket(buf, false, maxPacketSize, v) } -func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { +func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) { sealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { return shortHeaderPacket{}, err @@ -457,7 +471,7 @@ func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSi return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v) } -func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { +func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.Version) (*wire.ExtendedHeader, payload) { if onlyAck { if ack := p.acks.GetAckFrame(encLevel, true); ack != nil { return p.getLongHeader(encLevel, v), payload{ @@ -529,7 +543,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en return hdr, pl } -func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { +func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.Version) (*wire.ExtendedHeader, payload) { if p.perspective != protocol.PerspectiveClient { return nil, payload{} } @@ -539,12 +553,12 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v) } -func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { +func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload { maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v) } -func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { +func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload { pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v) // check if we have anything to send @@ -567,7 +581,7 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, return pl } -func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { +func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload { if onlyAck { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil { return payload{ack: ack, length: ack.Length(v)} @@ -575,12 +589,11 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc return payload{} } - pl := payload{streamFrames: make([]ackhandler.StreamFrame, 0, 1)} - hasData := p.framer.HasData() hasRetransmission := p.retransmissionQueue.HasAppData() var hasAck bool + var pl payload if ackAllowed { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil { pl.ack = ack @@ -592,11 +605,17 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc if p.datagramQueue != nil { if f := p.datagramQueue.Peek(); f != nil { size := f.Length(v) - if size <= maxFrameSize-pl.length { + if size <= maxFrameSize-pl.length { // DATAGRAM frame fits pl.frames = append(pl.frames, ackhandler.Frame{Frame: f}) pl.length += size p.datagramQueue.Pop() + } else if !hasAck { + // The DATAGRAM frame doesn't fit, and the packet doesn't contain an ACK. + // Discard this frame. There's no point in retrying this in the next packet, + // as it's unlikely that the available packet size will increase. + p.datagramQueue.Pop() } + // If the DATAGRAM frame was too large and the packet contained an ACK, we'll try to send it out later. } } @@ -626,7 +645,13 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc pl.length += lengthAdded // add handlers for the control frames that were added for i := startLen; i < len(pl.frames); i++ { - pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler() + switch pl.frames[i].Frame.(type) { + case *wire.PathChallengeFrame, *wire.PathResponseFrame: + // Path probing is currently not supported, therefore we don't need to set the OnAcked callback yet. + // PATH_CHALLENGE and PATH_RESPONSE are never retransmitted. + default: + pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler() + } } pl.streamFrames, lengthAdded = p.framer.AppendStreamFrames(pl.streamFrames, maxFrameSize-pl.length, v) @@ -635,7 +660,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc return pl } -func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { if encLevel == protocol.Encryption1RTT { s, err := p.cryptoSetup.Get1RTTSealer() if err != nil { @@ -701,7 +726,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m return packet, nil } -func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { +func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { pl := payload{ frames: []ackhandler.Frame{ping}, length: ping.Frame.Length(v), @@ -719,7 +744,7 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B return packet, buffer, err } -func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader { +func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader { pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) hdr := &wire.ExtendedHeader{ PacketNumber: pn, @@ -742,7 +767,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protoc return hdr } -func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) { +func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.Version) (*longHeaderPacket, error) { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) if pl.length < 4-pnLen { @@ -788,7 +813,7 @@ func (p *packetPacker) appendShortHeaderPacket( padding, maxPacketSize protocol.ByteCount, sealer sealer, isMTUProbePacket bool, - v protocol.VersionNumber, + v protocol.Version, ) (shortHeaderPacket, error) { var paddingLen protocol.ByteCount if pl.length < 4-protocol.ByteCount(pnLen) { @@ -832,7 +857,9 @@ func (p *packetPacker) appendShortHeaderPacket( }, nil } -func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { +// appendPacketPayload serializes the payload of a packet into the raw byte slice. +// It modifies the order of payload.frames. +func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.Version) ([]byte, error) { payloadOffset := len(raw) if pl.ack != nil { var err error @@ -844,6 +871,11 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr if paddingLen > 0 { raw = append(raw, make([]byte, paddingLen)...) } + // Randomize the order of the control frames. + // This makes sure that the receiver doesn't rely on the order in which frames are packed. + if len(pl.frames) > 1 { + p.rand.Shuffle(len(pl.frames), func(i, j int) { pl.frames[i], pl.frames[j] = pl.frames[j], pl.frames[i] }) + } for _, f := range pl.frames { var err error raw, err = f.Frame.Append(raw, v) diff --git a/vendor/github.com/quic-go/quic-go/packet_unpacker.go b/vendor/github.com/quic-go/quic-go/packet_unpacker.go index 103524c7d..1034aab1d 100644 --- a/vendor/github.com/quic-go/quic-go/packet_unpacker.go +++ b/vendor/github.com/quic-go/quic-go/packet_unpacker.go @@ -53,7 +53,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetU // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. // If any other error occurred when parsing the header, the error is of type headerParseError. // If decrypting the payload fails for any reason, the error is the error returned by the AEAD. -func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) { +func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) { var encLevel protocol.EncryptionLevel var extHdr *wire.ExtendedHeader var decrypted []byte @@ -125,7 +125,7 @@ func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (prot return pn, pnLen, kp, decrypted, nil } -func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, []byte, error) { +func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) { extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v) // If the reserved bits are set incorrectly, we still need to continue unpacking. // This avoids a timing side-channel, which otherwise might allow an attacker @@ -187,7 +187,7 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int } // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. -func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) { +func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) { extHdr, err := unpackLongHeader(hd, hdr, data, v) if err != nil && err != wire.ErrInvalidReservedBits { return nil, &headerParseError{err: err} @@ -195,7 +195,7 @@ func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, return extHdr, err } -func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) { +func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) { r := bytes.NewReader(data) hdrLen := hdr.ParsedLen() diff --git a/vendor/github.com/quic-go/quic-go/receive_stream.go b/vendor/github.com/quic-go/quic-go/receive_stream.go index 89d02b737..1235ff0e7 100644 --- a/vendor/github.com/quic-go/quic-go/receive_stream.go +++ b/vendor/github.com/quic-go/quic-go/receive_stream.go @@ -292,10 +292,6 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) return newlyRcvdFinalOffset, nil } -func (s *receiveStream) CloseRemote(offset protocol.ByteCount) { - s.handleStreamFrame(&wire.StreamFrame{Fin: true, Offset: offset}) -} - func (s *receiveStream) SetReadDeadline(t time.Time) error { s.mutex.Lock() s.deadline = t diff --git a/vendor/github.com/quic-go/quic-go/retransmission_queue.go b/vendor/github.com/quic-go/quic-go/retransmission_queue.go index 909ad6224..14ae1e3b8 100644 --- a/vendor/github.com/quic-go/quic-go/retransmission_queue.go +++ b/vendor/github.com/quic-go/quic-go/retransmission_queue.go @@ -74,7 +74,7 @@ func (q *retransmissionQueue) addAppData(f wire.Frame) { q.appData = append(q.appData, f) } -func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { +func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame { if len(q.initialCryptoData) > 0 { f := q.initialCryptoData[0] newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v) @@ -97,7 +97,7 @@ func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v proto return f } -func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { +func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame { if len(q.handshakeCryptoData) > 0 { f := q.handshakeCryptoData[0] newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v) @@ -120,7 +120,7 @@ func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v pro return f } -func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { +func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.Version) wire.Frame { if len(q.appData) == 0 { return nil } diff --git a/vendor/github.com/quic-go/quic-go/send_conn.go b/vendor/github.com/quic-go/quic-go/send_conn.go index d8ddbc871..498ed112b 100644 --- a/vendor/github.com/quic-go/quic-go/send_conn.go +++ b/vendor/github.com/quic-go/quic-go/send_conn.go @@ -1,8 +1,6 @@ package quic import ( - "fmt" - "math" "net" "github.com/quic-go/quic-go/internal/protocol" @@ -11,7 +9,7 @@ import ( // A sendConn allows sending using a simple Write() on a non-connected packet conn. type sendConn interface { - Write(b []byte, size protocol.ByteCount) error + Write(b []byte, gsoSize uint16, ecn protocol.ECN) error Close() error LocalAddr() net.Addr RemoteAddr() net.Addr @@ -27,10 +25,12 @@ type sconn struct { logger utils.Logger - info packetInfo - oob []byte + packetInfoOOB []byte // If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled. gotGSOError bool + // Used to catch the error sometimes returned by the first sendmsg call on Linux, + // see https://github.com/golang/go/issues/63322. + wroteFirstPacket bool } var _ sendConn = &sconn{} @@ -46,33 +46,20 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge } oob := info.OOB() - // add 32 bytes, so we can add the UDP_SEGMENT msg + // increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating l := len(oob) - oob = append(oob, make([]byte, 32)...) - oob = oob[:l] + oob = append(oob, make([]byte, 64)...)[:l] return &sconn{ - rawConn: c, - localAddr: localAddr, - remoteAddr: remote, - info: info, - oob: oob, - logger: logger, + rawConn: c, + localAddr: localAddr, + remoteAddr: remote, + packetInfoOOB: oob, + logger: logger, } } -func (c *sconn) Write(p []byte, size protocol.ByteCount) error { - if !c.capabilities().GSO { - if protocol.ByteCount(len(p)) != size { - panic(fmt.Sprintf("inconsistent packet size (%d vs %d)", len(p), size)) - } - _, err := c.WritePacket(p, c.remoteAddr, c.oob) - return err - } - // GSO is supported. Append the control message and send. - if size > math.MaxUint16 { - panic("size overflow") - } - _, err := c.WritePacket(p, c.remoteAddr, appendUDPSegmentSizeMsg(c.oob, uint16(size))) +func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error { + err := c.writePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn) if err != nil && isGSOError(err) { // disable GSO for future calls c.gotGSOError = true @@ -82,10 +69,10 @@ func (c *sconn) Write(p []byte, size protocol.ByteCount) error { // send out the packets one by one for len(p) > 0 { l := len(p) - if l > int(size) { - l = int(size) + if l > int(gsoSize) { + l = int(gsoSize) } - if _, err := c.WritePacket(p[:l], c.remoteAddr, c.oob); err != nil { + if err := c.writePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil { return err } p = p[l:] @@ -95,6 +82,15 @@ func (c *sconn) Write(p []byte, size protocol.ByteCount) error { return err } +func (c *sconn) writePacket(p []byte, addr net.Addr, oob []byte, gsoSize uint16, ecn protocol.ECN) error { + _, err := c.WritePacket(p, addr, oob, gsoSize, ecn) + if err != nil && !c.wroteFirstPacket && isPermissionError(err) { + _, err = c.WritePacket(p, addr, oob, gsoSize, ecn) + } + c.wroteFirstPacket = true + return err +} + func (c *sconn) capabilities() connCapabilities { capabilities := c.rawConn.capabilities() if capabilities.GSO { diff --git a/vendor/github.com/quic-go/quic-go/send_queue.go b/vendor/github.com/quic-go/quic-go/send_queue.go index a9f7ca1a0..bde023348 100644 --- a/vendor/github.com/quic-go/quic-go/send_queue.go +++ b/vendor/github.com/quic-go/quic-go/send_queue.go @@ -3,7 +3,7 @@ package quic import "github.com/quic-go/quic-go/internal/protocol" type sender interface { - Send(p *packetBuffer, packetSize protocol.ByteCount) + Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) Run() error WouldBlock() bool Available() <-chan struct{} @@ -11,8 +11,9 @@ type sender interface { } type queueEntry struct { - buf *packetBuffer - size protocol.ByteCount + buf *packetBuffer + gsoSize uint16 + ecn protocol.ECN } type sendQueue struct { @@ -40,9 +41,9 @@ func newSendQueue(conn sendConn) sender { // Send sends out a packet. It's guaranteed to not block. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Otherwise Send will panic. -func (h *sendQueue) Send(p *packetBuffer, size protocol.ByteCount) { +func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) { select { - case h.queue <- queueEntry{buf: p, size: size}: + case h.queue <- queueEntry{buf: p, gsoSize: gsoSize, ecn: ecn}: // clear available channel if we've reached capacity if len(h.queue) == sendQueueCapacity { select { @@ -77,7 +78,7 @@ func (h *sendQueue) Run() error { // make sure that all queued packets are actually sent out shouldClose = true case e := <-h.queue: - if err := h.conn.Write(e.buf.Data, e.size); err != nil { + if err := h.conn.Write(e.buf.Data, e.gsoSize, e.ecn); err != nil { // This additional check enables: // 1. Checking for "datagram too large" message from the kernel, as such, // 2. Path MTU discovery,and diff --git a/vendor/github.com/quic-go/quic-go/send_stream.go b/vendor/github.com/quic-go/quic-go/send_stream.go index 4113d9f0c..e1ce3e677 100644 --- a/vendor/github.com/quic-go/quic-go/send_stream.go +++ b/vendor/github.com/quic-go/quic-go/send_stream.go @@ -18,7 +18,7 @@ type sendStreamI interface { SendStream handleStopSendingFrame(*wire.StopSendingFrame) hasData() bool - popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (frame ackhandler.StreamFrame, ok, hasMore bool) + popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (frame ackhandler.StreamFrame, ok, hasMore bool) closeForShutdown(error) updateSendWindow(protocol.ByteCount) } @@ -198,7 +198,7 @@ func (s *sendStream) canBufferStreamFrame() bool { // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // maxBytes is the maximum length this frame (including frame header) will have. -func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (af ackhandler.StreamFrame, ok, hasMore bool) { +func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (af ackhandler.StreamFrame, ok, hasMore bool) { s.mutex.Lock() f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v) if f != nil { @@ -215,7 +215,7 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers }, true, hasMoreData } -func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) { +func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more data to send */) { if s.cancelWriteErr != nil || s.closeForShutdownErr != nil { return nil, false } @@ -269,12 +269,12 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun return f, hasMoreData } -func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool) { +func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool) { if s.nextFrame != nil { nextFrame := s.nextFrame s.nextFrame = nil - maxDataLen := utils.Min(sendWindow, nextFrame.MaxDataLen(maxBytes, v)) + maxDataLen := min(sendWindow, nextFrame.MaxDataLen(maxBytes, v)) if nextFrame.DataLen() > maxDataLen { s.nextFrame = wire.GetStreamFrame() s.nextFrame.StreamID = s.streamID @@ -304,17 +304,17 @@ func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, return f, hasMoreData } -func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) bool { +func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.Version) bool { maxDataLen := f.MaxDataLen(maxBytes, v) if maxDataLen == 0 { // a STREAM frame must have at least one byte of data return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting } - s.getDataForWriting(f, utils.Min(maxDataLen, sendWindow)) + s.getDataForWriting(f, min(maxDataLen, sendWindow)) return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting } -func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more retransmissions */) { +func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.Version) (*wire.StreamFrame, bool /* has more retransmissions */) { f := s.retransmissionQueue[0] newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v) if needsSplit { @@ -404,11 +404,13 @@ func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool } func (s *sendStream) updateSendWindow(limit protocol.ByteCount) { + updated := s.flowController.UpdateSendWindow(limit) + if !updated { // duplicate or reordered MAX_STREAM_DATA frame + return + } s.mutex.Lock() hasStreamData := s.dataForWriting != nil || s.nextFrame != nil s.mutex.Unlock() - - s.flowController.UpdateSendWindow(limit) if hasStreamData { s.sender.onHasStreamData(s.streamID) } diff --git a/vendor/github.com/quic-go/quic-go/server.go b/vendor/github.com/quic-go/quic-go/server.go index c06228c91..afbd18fd2 100644 --- a/vendor/github.com/quic-go/quic-go/server.go +++ b/vendor/github.com/quic-go/quic-go/server.go @@ -2,13 +2,11 @@ package quic import ( "context" - "crypto/rand" "crypto/tls" "errors" "fmt" "net" "sync" - "sync/atomic" "time" "github.com/quic-go/quic-go/internal/handshake" @@ -25,17 +23,15 @@ var ErrServerClosed = errors.New("quic: server closed") // packetHandler handles packets type packetHandler interface { handlePacket(receivedPacket) - shutdown() destroy(error) - getPerspective() protocol.Perspective + closeWithTransportError(qerr.TransportErrorCode) } type packetHandlerManager interface { Get(protocol.ConnectionID) (packetHandler, bool) GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool) - AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool + AddWithConnID(destConnID, newConnID protocol.ConnectionID, h packetHandler) bool Close(error) - CloseServer() connRunner } @@ -43,11 +39,9 @@ type quicConn interface { EarlyConnection earlyConnReady() <-chan struct{} handlePacket(receivedPacket) - GetVersion() protocol.VersionNumber - getPerspective() protocol.Perspective run() error destroy(error) - shutdown() + closeWithTransportError(TransportErrorCode) } type zeroRTTQueue struct { @@ -55,11 +49,15 @@ type zeroRTTQueue struct { expiration time.Time } +type rejectedPacket struct { + receivedPacket + hdr *wire.Header +} + // A Listener of QUIC type baseServer struct { - mutex sync.Mutex - - acceptEarlyConns bool + disableVersionNegotiation bool + acceptEarlyConns bool tlsConf *tls.Config config *Config @@ -67,6 +65,7 @@ type baseServer struct { conn rawConn tokenGenerator *handshake.TokenGenerator + maxTokenAge time.Duration connIDGenerator ConnectionIDGenerator connHandler packetHandlerManager @@ -92,23 +91,27 @@ type baseServer struct { *tls.Config, *handshake.TokenGenerator, bool, /* client address validated by an address validation token */ - logging.ConnectionTracer, + *logging.ConnectionTracer, uint64, utils.Logger, - protocol.VersionNumber, + protocol.Version, ) quicConn - serverError error - errorChan chan struct{} - closed bool - running chan struct{} // closed as soon as run() returns + closeMx sync.Mutex + errorChan chan struct{} // is closed when the server is closed + closeErr error + running chan struct{} // closed as soon as run() returns + versionNegotiationQueue chan receivedPacket - invalidTokenQueue chan receivedPacket + invalidTokenQueue chan rejectedPacket + connectionRefusedQueue chan rejectedPacket + retryQueue chan rejectedPacket + + verifySourceAddress func(net.Addr) bool - connQueue chan quicConn - connQueueLen int32 // to be used as an atomic + connQueue chan quicConn - tracer logging.Tracer + tracer *logging.Tracer logger utils.Logger } @@ -124,7 +127,12 @@ func (l *Listener) Accept(ctx context.Context) (Connection, error) { return l.baseServer.Accept(ctx) } -// Close the server. All active connections will be closed. +// Close closes the listener. +// Accept will return ErrServerClosed as soon as all connections in the accept queue have been accepted. +// QUIC handshakes that are still in flight will be rejected with a CONNECTION_REFUSED error. +// The effect of closing the listener depends on how it was created: +// * if it was created using Transport.Listen, already established connections will be unaffected +// * if it was created using the Listen convenience method, all established connection will be closed immediately func (l *Listener) Close() error { return l.baseServer.Close() } @@ -207,6 +215,7 @@ func listenUDP(addr string) (*net.UDPConn, error) { // This is a convenience function. More advanced use cases should instantiate a Transport, // which offers configuration options for a more fine-grained control of the connection establishment, // including reusing the underlying UDP socket for outgoing QUIC connections. +// When closing a listener created with Listen, all established QUIC connections will be closed immediately. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) { tr := &Transport{Conn: conn, isSingleUse: true} return tr.Listen(tlsConf, config) @@ -224,32 +233,37 @@ func newServer( connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, - tracer logging.Tracer, + tracer *logging.Tracer, onClose func(), + tokenGeneratorKey TokenGeneratorKey, + maxTokenAge time.Duration, + verifySourceAddress func(net.Addr) bool, + disableVersionNegotiation bool, acceptEarly bool, -) (*baseServer, error) { - tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) - if err != nil { - return nil, err - } +) *baseServer { s := &baseServer{ - conn: conn, - tlsConf: tlsConf, - config: config, - tokenGenerator: tokenGenerator, - connIDGenerator: connIDGenerator, - connHandler: connHandler, - connQueue: make(chan quicConn), - errorChan: make(chan struct{}), - running: make(chan struct{}), - receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), - versionNegotiationQueue: make(chan receivedPacket, 4), - invalidTokenQueue: make(chan receivedPacket, 4), - newConn: newConnection, - tracer: tracer, - logger: utils.DefaultLogger.WithPrefix("server"), - acceptEarlyConns: acceptEarly, - onClose: onClose, + conn: conn, + tlsConf: tlsConf, + config: config, + tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey), + maxTokenAge: maxTokenAge, + verifySourceAddress: verifySourceAddress, + connIDGenerator: connIDGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), + versionNegotiationQueue: make(chan receivedPacket, 4), + invalidTokenQueue: make(chan rejectedPacket, 4), + connectionRefusedQueue: make(chan rejectedPacket, 4), + retryQueue: make(chan rejectedPacket, 8), + newConn: newConnection, + tracer: tracer, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, + disableVersionNegotiation: disableVersionNegotiation, + onClose: onClose, } if acceptEarly { s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} @@ -257,7 +271,7 @@ func newServer( go s.run() go s.runSendQueue() s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) - return s, nil + return s } func (s *baseServer) run() { @@ -288,6 +302,10 @@ func (s *baseServer) runSendQueue() { s.maybeSendVersionNegotiationPacket(p) case p := <-s.invalidTokenQueue: s.maybeSendInvalidToken(p) + case p := <-s.connectionRefusedQueue: + s.sendConnectionRefused(p) + case p := <-s.retryQueue: + s.sendRetry(p) } } } @@ -303,41 +321,31 @@ func (s *baseServer) accept(ctx context.Context) (quicConn, error) { case <-ctx.Done(): return nil, ctx.Err() case conn := <-s.connQueue: - atomic.AddInt32(&s.connQueueLen, -1) return conn, nil case <-s.errorChan: - return nil, s.serverError + return nil, s.closeErr } } -// Close the server func (s *baseServer) Close() error { - s.mutex.Lock() - if s.closed { - s.mutex.Unlock() - return nil - } - if s.serverError == nil { - s.serverError = ErrServerClosed - } - s.closed = true - close(s.errorChan) - s.mutex.Unlock() - - <-s.running - s.onClose() + s.close(ErrServerClosed, true) return nil } -func (s *baseServer) setCloseError(e error) { - s.mutex.Lock() - defer s.mutex.Unlock() - if s.closed { +func (s *baseServer) close(e error, notifyOnClose bool) { + s.closeMx.Lock() + if s.closeErr != nil { + s.closeMx.Unlock() return } - s.closed = true - s.serverError = e + s.closeErr = e close(s.errorChan) + <-s.running + s.closeMx.Unlock() + + if notifyOnClose { + s.onClose() + } } // Addr returns the server's network address @@ -350,7 +358,7 @@ func (s *baseServer) handlePacket(p receivedPacket) { case s.receivedPackets <- p: default: s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } @@ -363,7 +371,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st if wire.IsVersionNegotiationPacket(p.data) { s.logger.Debugf("Dropping Version Negotiation packet.") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -376,20 +384,20 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // drop the packet if we failed to parse the protocol version if err != nil { s.logger.Debugf("Dropping a packet with an unknown version") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false } // send a Version Negotiation Packet if the client is speaking a different protocol version if !protocol.IsSupportedVersion(s.config.Versions, v) { - if s.config.DisableVersionNegotiationPackets { + if s.disableVersionNegotiation { return false } if p.Size() < protocol.MinUnknownVersionPacketSize { s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -399,7 +407,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st if wire.Is0RTTPacket(p.data) { if !s.acceptEarlyConns { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -411,7 +419,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) @@ -419,7 +427,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st } if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -430,7 +438,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // There's little point in sending a Stateless Reset, since the client // might not have received the token yet. s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -449,7 +457,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { connID, err := wire.ParseConnectionID(p.data, 0) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) } return false @@ -463,7 +471,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { if q, ok := s.zeroRTTQueues[connID]; ok { if len(q.packets) >= protocol.Max0RTTQueueLen { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false @@ -473,7 +481,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { } if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false @@ -501,7 +509,7 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) { continue } for _, p := range q.packets { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } p.buffer.Release() @@ -525,10 +533,10 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { if !token.ValidateRemoteAddr(addr) { return false } - if !token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxTokenAge { + if !token.IsRetryToken && time.Since(token.SentTime) > s.maxTokenAge { return false } - if token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxRetryTokenAge { + if token.IsRetryToken && time.Since(token.SentTime) > s.config.maxRetryTokenAge() { return false } return true @@ -536,10 +544,10 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { - p.buffer.Release() - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } + p.buffer.Release() return errors.New("too short connection ID") } @@ -552,8 +560,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error } var ( - token *handshake.Token - retrySrcConnID *protocol.ConnectionID + token *handshake.Token + retrySrcConnID *protocol.ConnectionID + clientAddrVerified bool ) origDestConnID := hdr.DestConnectionID if len(hdr.Token) > 0 { @@ -566,147 +575,162 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error token = tok } } - - clientAddrIsValid := s.validateToken(token, p.remoteAddr) - if token != nil && !clientAddrIsValid { - // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. - // We just ignore them, and act as if there was no token on this packet at all. - // This also means we might send a Retry later. - if !token.IsRetryToken { - token = nil - } else { - // For Retry tokens, we send an INVALID_ERROR if - // * the token is too old, or - // * the token is invalid, in case of a retry token. - s.enqueueInvalidToken(p) - return nil + if token != nil { + clientAddrVerified = s.validateToken(token, p.remoteAddr) + if !clientAddrVerified { + // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. + // We just ignore them, and act as if there was no token on this packet at all. + // This also means we might send a Retry later. + if !token.IsRetryToken { + token = nil + } else { + // For Retry tokens, we send an INVALID_ERROR if + // * the token is too old, or + // * the token is invalid, in case of a retry token. + select { + case s.invalidTokenQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: + default: + // drop packet if we can't send out the INVALID_TOKEN packets fast enough + p.buffer.Release() + } + return nil + } } } - if token == nil && s.config.RequireAddressValidation(p.remoteAddr) { + + if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) { // Retry invalidates all 0-RTT packets sent. delete(s.zeroRTTQueues, hdr.DestConnectionID) - go func() { - defer p.buffer.Release() - if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { - s.logger.Debugf("Error sending Retry: %s", err) - } - }() + select { + case s.retryQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: + default: + // drop packet if we can't send out Retry packets fast enough + p.buffer.Release() + } return nil } - if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize { - s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) - go func() { - defer p.buffer.Release() - if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil { - s.logger.Debugf("Error rejecting connection: %s", err) + config := s.config + if s.config.GetConfigForClient != nil { + conf, err := s.config.GetConfigForClient(&ClientHelloInfo{ + RemoteAddr: p.remoteAddr, + AddrVerified: clientAddrVerified, + }) + if err != nil { + s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback") + delete(s.zeroRTTQueues, hdr.DestConnectionID) + select { + case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}: + default: + // drop packet if we can't send out the CONNECTION_REFUSED fast enough + p.buffer.Release() } - }() - return nil + return nil + } + config = populateConfig(conf) } + var conn quicConn + tracingID := nextConnTracingID() + var tracer *logging.ConnectionTracer + if config.Tracer != nil { + // Use the same connection ID that is passed to the client's GetLogWriter callback. + connID := hdr.DestConnectionID + if origDestConnID.Len() > 0 { + connID = origDestConnID + } + tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) + } connID, err := s.connIDGenerator.GenerateConnectionID() if err != nil { return err } s.logger.Debugf("Changing connection ID to %s.", connID) - var conn quicConn - tracingID := nextConnTracingID() - if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) { - config := s.config - if s.config.GetConfigForClient != nil { - conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr}) - if err != nil { - s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback") - return nil, false - } - config = populateConfig(conf) - } - var tracer logging.ConnectionTracer - if config.Tracer != nil { - // Use the same connection ID that is passed to the client's GetLogWriter callback. - connID := hdr.DestConnectionID - if origDestConnID.Len() > 0 { - connID = origDestConnID - } - tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) - } - conn = s.newConn( - newSendConn(s.conn, p.remoteAddr, p.info, s.logger), - s.connHandler, - origDestConnID, - retrySrcConnID, - hdr.DestConnectionID, - hdr.SrcConnectionID, - connID, - s.connIDGenerator, - s.connHandler.GetStatelessResetToken(connID), - config, - s.tlsConf, - s.tokenGenerator, - clientAddrIsValid, - tracer, - tracingID, - s.logger, - hdr.Version, - ) - conn.handlePacket(p) - - if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok { - for _, p := range q.packets { - conn.handlePacket(p) - } - delete(s.zeroRTTQueues, hdr.DestConnectionID) - } - - return conn, true - }); !added { - go func() { - defer p.buffer.Release() - if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil { - s.logger.Debugf("Error rejecting connection: %s", err) - } - }() + conn = s.newConn( + newSendConn(s.conn, p.remoteAddr, p.info, s.logger), + s.connHandler, + origDestConnID, + retrySrcConnID, + hdr.DestConnectionID, + hdr.SrcConnectionID, + connID, + s.connIDGenerator, + s.connHandler.GetStatelessResetToken(connID), + config, + s.tlsConf, + s.tokenGenerator, + clientAddrVerified, + tracer, + tracingID, + s.logger, + hdr.Version, + ) + conn.handlePacket(p) + // Adding the connection will fail if the client's chosen Destination Connection ID is already in use. + // This is very unlikely: Even if an attacker chooses a connection ID that's already in use, + // under normal circumstances the packet would just be routed to that connection. + // The only time this collision will occur if we receive the two Initial packets at the same time. + if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, conn); !added { + delete(s.zeroRTTQueues, hdr.DestConnectionID) + conn.closeWithTransportError(qerr.ConnectionRefused) return nil } - go conn.run() - go s.handleNewConn(conn) - if conn == nil { - p.buffer.Release() - return nil + // Pass queued 0-RTT to the newly established connection. + if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok { + for _, p := range q.packets { + conn.handlePacket(p) + } + delete(s.zeroRTTQueues, hdr.DestConnectionID) } + + go conn.run() + go func() { + if completed := s.handleNewConn(conn); !completed { + return + } + + select { + case s.connQueue <- conn: + default: + conn.closeWithTransportError(ConnectionRefused) + } + }() return nil } -func (s *baseServer) handleNewConn(conn quicConn) { - connCtx := conn.Context() +func (s *baseServer) handleNewConn(conn quicConn) bool { if s.acceptEarlyConns { - // wait until the early connection is ready (or the handshake fails) + // wait until the early connection is ready, the handshake fails, or the server is closed select { + case <-s.errorChan: + conn.closeWithTransportError(ConnectionRefused) + return false + case <-conn.Context().Done(): + return false case <-conn.earlyConnReady(): - case <-connCtx.Done(): - return - } - } else { - // wait until the handshake is complete (or fails) - select { - case <-conn.HandshakeComplete(): - case <-connCtx.Done(): - return + return true } } - - atomic.AddInt32(&s.connQueueLen, 1) + // wait until the handshake completes, fails, or the server is closed select { - case s.connQueue <- conn: - // blocks until the connection is accepted - case <-connCtx.Done(): - atomic.AddInt32(&s.connQueueLen, -1) - // don't pass connections that were already closed to Accept() + case <-s.errorChan: + conn.closeWithTransportError(ConnectionRefused) + return false + case <-conn.Context().Done(): + return false + case <-conn.HandshakeComplete(): + return true + } +} + +func (s *baseServer) sendRetry(p rejectedPacket) { + if err := s.sendRetryPacket(p); err != nil { + s.logger.Debugf("Error sending Retry packet: %s", err) } } -func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error { +func (s *baseServer) sendRetryPacket(p rejectedPacket) error { + hdr := p.hdr // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the connection. (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) @@ -714,7 +738,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe if err != nil { return err } - token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID, srcConnID) + token, err := s.tokenGenerator.NewRetryToken(p.remoteAddr, hdr.DestConnectionID, srcConnID) if err != nil { return err } @@ -739,50 +763,33 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe // append the Retry integrity tag tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) buf.Data = append(buf.Data, tag[:]...) - if s.tracer != nil { - s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) + if s.tracer != nil && s.tracer.SentPacket != nil { + s.tracer.SentPacket(p.remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(buf.Data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported) return err } -func (s *baseServer) enqueueInvalidToken(p receivedPacket) { - select { - case s.invalidTokenQueue <- p: - default: - // it's fine to drop INVALID_TOKEN packets when we are busy - p.buffer.Release() - } -} - -func (s *baseServer) maybeSendInvalidToken(p receivedPacket) { +func (s *baseServer) maybeSendInvalidToken(p rejectedPacket) { defer p.buffer.Release() - hdr, _, _, err := wire.ParsePacket(p.data) - if err != nil { - if s.tracer != nil { - s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) - } - s.logger.Debugf("Error parsing packet: %s", err) - return - } - // Only send INVALID_TOKEN if we can unprotect the packet. // This makes sure that we won't send it for packets that were corrupted. + hdr := p.hdr sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) data := p.data[:hdr.ParsedLen()+hdr.Length] extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version) // Only send INVALID_TOKEN if we can unprotect the packet. // This makes sure that we won't send it for packets that were corrupted. if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) } return } hdrLen := extHdr.ParsedLen() if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) } return @@ -795,9 +802,12 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) { } } -func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error { - sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) - return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info) +func (s *baseServer) sendConnectionRefused(p rejectedPacket) { + defer p.buffer.Release() + sealer, _ := handshake.NewInitialAEAD(p.hdr.DestConnectionID, protocol.PerspectiveServer, p.hdr.Version) + if err := s.sendError(p.remoteAddr, p.hdr, sealer, qerr.ConnectionRefused, p.info); err != nil { + s.logger.Debugf("Error sending CONNECTION_REFUSED error: %s", err) + } } // sendError sends the error as a response to the packet received with header hdr @@ -838,10 +848,10 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han replyHdr.Log(s.logger) wire.LogFrame(s.logger, ccf, true) - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentPacket != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) return err } @@ -867,7 +877,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) if err != nil { // should never happen s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return @@ -876,10 +886,10 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentVersionNegotiationPacket != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/vendor/github.com/quic-go/quic-go/stream.go b/vendor/github.com/quic-go/quic-go/stream.go index ab76eaf80..ce4374d60 100644 --- a/vendor/github.com/quic-go/quic-go/stream.go +++ b/vendor/github.com/quic-go/quic-go/stream.go @@ -60,7 +60,7 @@ type streamI interface { // for sending hasData() bool handleStopSendingFrame(*wire.StopSendingFrame) - popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) + popStreamFrame(maxBytes protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool) updateSendWindow(protocol.ByteCount) } diff --git a/vendor/github.com/quic-go/quic-go/sys_conn.go b/vendor/github.com/quic-go/quic-go/sys_conn.go index f2224e4cc..71cc46070 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn.go @@ -104,7 +104,13 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { +func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, ecn protocol.ECN) (n int, err error) { + if gsoSize != 0 { + panic("cannot use GSO with a basicConn") + } + if ecn != protocol.ECNUnsupported { + panic("cannot use ECN with a basicConn") + } return c.PacketConn.WriteTo(b, addr) } diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_helper_darwin.go b/vendor/github.com/quic-go/quic-go/sys_conn_helper_darwin.go index 758cf7788..d761072f2 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_helper_darwin.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_helper_darwin.go @@ -15,6 +15,8 @@ const ( ipv4PKTINFO = unix.IP_RECVPKTINFO ) +const ecnIPv4DataLen = 4 + // ReadBatch only returns a single packet on OSX, // see https://godoc.org/golang.org/x/net/ipv4#PacketConn.ReadBatch. const batchSize = 1 diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_helper_freebsd.go b/vendor/github.com/quic-go/quic-go/sys_conn_helper_freebsd.go index a2baae3b3..a53ca2eae 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_helper_freebsd.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_helper_freebsd.go @@ -14,6 +14,8 @@ const ( ipv4PKTINFO = 0x7 ) +const ecnIPv4DataLen = 1 + const batchSize = 8 func parseIPv4PktInfo(body []byte) (ip netip.Addr, _ uint32, ok bool) { diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_helper_linux.go b/vendor/github.com/quic-go/quic-go/sys_conn_helper_linux.go index 6a049241b..5fbf34ade 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_helper_linux.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_helper_linux.go @@ -19,6 +19,8 @@ const ( ipv4PKTINFO = unix.IP_PKTINFO ) +const ecnIPv4DataLen = 1 + const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed) func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error { @@ -95,3 +97,14 @@ func isGSOError(err error) bool { } return false } + +// The first sendmsg call on a new UDP socket sometimes errors on Linux. +// It's not clear why this happens. +// See https://github.com/golang/go/issues/63322. +func isPermissionError(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + return serr.Syscall == "sendmsg" && serr.Err == unix.EPERM + } + return false +} diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_helper_nonlinux.go b/vendor/github.com/quic-go/quic-go/sys_conn_helper_nonlinux.go index cace82d5d..f8d69803b 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_helper_nonlinux.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_helper_nonlinux.go @@ -7,3 +7,4 @@ func forceSetSendBuffer(c any, bytes int) error { return nil } func appendUDPSegmentSizeMsg([]byte, uint16) []byte { return nil } func isGSOError(error) bool { return false } +func isPermissionError(err error) bool { return false } diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_oob.go b/vendor/github.com/quic-go/quic-go/sys_conn_oob.go index 4026a7b32..64d581c06 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_oob.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_oob.go @@ -8,9 +8,12 @@ import ( "log" "net" "net/netip" + "os" + "strconv" "sync" "syscall" "time" + "unsafe" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -56,6 +59,11 @@ func inspectWriteBuffer(c syscall.RawConn) (int, error) { return size, serr } +func isECNDisabled() bool { + disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_ECN")) + return err == nil && disabled +} + type oobConn struct { OOBCapablePacketConn batchConn batchConn @@ -140,6 +148,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { cap: connCapabilities{ DF: supportsDF, GSO: isGSOSupported(rawConn), + ECN: !isECNDisabled(), }, } for i := 0; i < batchSize; i++ { @@ -188,7 +197,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { if hdr.Level == unix.IPPROTO_IP { switch hdr.Type { case msgTypeIPTOS: - p.ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask) case ipv4PKTINFO: ip, ifIndex, ok := parseIPv4PktInfo(body) if ok { @@ -205,14 +214,14 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { if hdr.Level == unix.IPPROTO_IPV6 { switch hdr.Type { case unix.IPV6_TCLASS: - p.ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask) case unix.IPV6_PKTINFO: // struct in6_pktinfo { // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // unsigned int ipi6_ifindex; /* send/recv interface index */ // }; if len(body) == 20 { - p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16])) + p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16])).Unmap() p.info.ifIndex = binary.LittleEndian.Uint32(body[16:]) } else { invalidCmsgOnceV6.Do(func() { @@ -228,8 +237,26 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) { } // WritePacket writes a new packet. -// If the connection supports GSO, it's the caller's responsibility to append the right control mesage. -func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) { +func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) { + oob := packetInfoOOB + if gsoSize > 0 { + if !c.capabilities().GSO { + panic("GSO disabled") + } + oob = appendUDPSegmentSizeMsg(oob, gsoSize) + } + if ecn != protocol.ECNUnsupported { + if !c.capabilities().ECN { + panic("tried to send a ECN-marked packet although ECN is disabled") + } + if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok { + if remoteUDPAddr.IP.To4() != nil { + oob = appendIPv4ECNMsg(oob, ecn) + } else { + oob = appendIPv6ECNMsg(oob, ecn) + } + } + } n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } @@ -273,3 +300,32 @@ func (info *packetInfo) OOB() []byte { } return nil } + +func appendIPv4ECNMsg(b []byte, val protocol.ECN) []byte { + startLen := len(b) + b = append(b, make([]byte, unix.CmsgSpace(ecnIPv4DataLen))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) + h.Level = syscall.IPPROTO_IP + h.Type = unix.IP_TOS + h.SetLen(unix.CmsgLen(ecnIPv4DataLen)) + + // UnixRights uses the private `data` method, but I *think* this achieves the same goal. + offset := startLen + unix.CmsgSpace(0) + b[offset] = val.ToHeaderBits() + return b +} + +func appendIPv6ECNMsg(b []byte, val protocol.ECN) []byte { + startLen := len(b) + const dataLen = 4 + b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) + h.Level = syscall.IPPROTO_IPV6 + h.Type = unix.IPV6_TCLASS + h.SetLen(unix.CmsgLen(dataLen)) + + // UnixRights uses the private `data` method, but I *think* this achieves the same goal. + offset := startLen + unix.CmsgSpace(0) + b[offset] = val.ToHeaderBits() + return b +} diff --git a/vendor/github.com/quic-go/quic-go/token_store.go b/vendor/github.com/quic-go/quic-go/token_store.go index 00460e502..a5c1c1852 100644 --- a/vendor/github.com/quic-go/quic-go/token_store.go +++ b/vendor/github.com/quic-go/quic-go/token_store.go @@ -3,7 +3,6 @@ package quic import ( "sync" - "github.com/quic-go/quic-go/internal/utils" list "github.com/quic-go/quic-go/internal/utils/linkedlist" ) @@ -20,14 +19,14 @@ func newSingleOriginTokenStore(size int) *singleOriginTokenStore { func (s *singleOriginTokenStore) Add(token *ClientToken) { s.tokens[s.p] = token s.p = s.index(s.p + 1) - s.len = utils.Min(s.len+1, len(s.tokens)) + s.len = min(s.len+1, len(s.tokens)) } func (s *singleOriginTokenStore) Pop() *ClientToken { s.p = s.index(s.p - 1) token := s.tokens[s.p] s.tokens[s.p] = nil - s.len = utils.Max(s.len-1, 0) + s.len = max(s.len-1, 0) return token } diff --git a/vendor/github.com/quic-go/quic-go/tools.go b/vendor/github.com/quic-go/quic-go/tools.go index e848317f1..d00ce748d 100644 --- a/vendor/github.com/quic-go/quic-go/tools.go +++ b/vendor/github.com/quic-go/quic-go/tools.go @@ -3,6 +3,6 @@ package quic import ( - _ "github.com/golang/mock/mockgen" _ "github.com/onsi/ginkgo/v2/ginkgo" + _ "go.uber.org/mock/mockgen" ) diff --git a/vendor/github.com/quic-go/quic-go/transport.go b/vendor/github.com/quic-go/quic-go/transport.go index d8da9b1a1..ea219c112 100644 --- a/vendor/github.com/quic-go/quic-go/transport.go +++ b/vendor/github.com/quic-go/quic-go/transport.go @@ -16,6 +16,8 @@ import ( "github.com/quic-go/quic-go/logging" ) +var errListenerAlreadySet = errors.New("listener already set") + // The Transport is the central point to manage incoming and outgoing QUIC connections. // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple. // This means that a single UDP socket can be used for listening for incoming connections, as well as @@ -39,7 +41,8 @@ type Transport struct { Conn net.PacketConn // The length of the connection ID in bytes. - // It can be 0, or any value between 4 and 18. + // It can be any value between 1 and 20. + // Due to the increased risk of collisions, it is not recommended to use connection IDs shorter than 4 bytes. // If unset, a 4 byte connection ID will be used. ConnectionIDLength int @@ -57,8 +60,38 @@ type Transport struct { // See section 10.3 of RFC 9000 for details. StatelessResetKey *StatelessResetKey + // The TokenGeneratorKey is used to encrypt session resumption tokens. + // If no key is configured, a random key will be generated. + // If multiple servers are authoritative for the same domain, they should use the same key, + // see section 8.1.3 of RFC 9000 for details. + TokenGeneratorKey *TokenGeneratorKey + + // MaxTokenAge is the maximum age of the resumption token presented during the handshake. + // These tokens allow skipping address resumption when resuming a QUIC connection, + // and are especially useful when using 0-RTT. + // If not set, it defaults to 24 hours. + // See section 8.1.3 of RFC 9000 for details. + MaxTokenAge time.Duration + + // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. + // This can be useful if version information is exchanged out-of-band. + // It has no effect for clients. + DisableVersionNegotiationPackets bool + + // VerifySourceAddress decides if a connection attempt originating from unvalidated source + // addresses first needs to go through source address validation using QUIC's Retry mechanism, + // as described in RFC 9000 section 8.1.2. + // Note that the address passed to this callback is unvalidated, and might be spoofed in case + // of an attack. + // Validating the source address adds one additional network roundtrip to the handshake, + // and should therefore only be used if a suspiciously high number of incoming connection is recorded. + // For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable + // implementation of this callback (negating its return value). + VerifySourceAddress func(net.Addr) bool + // A Tracer traces events that don't belong to a single QUIC connection. - Tracer logging.Tracer + // Tracer.Close is called when the transport is closed. + Tracer *logging.Tracer handlerMap packetHandlerManager @@ -73,7 +106,7 @@ type Transport struct { // If no ConnectionIDGenerator is set, this is set to a default. connIDGenerator ConnectionIDGenerator - server unknownPacketHandler + server *baseServer conn rawConn @@ -95,28 +128,10 @@ type Transport struct { // There can only be a single listener on any net.PacketConn. // Listen may only be called again after the current Listener was closed. func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(conf); err != nil { - return nil, err - } - - t.mutex.Lock() - defer t.mutex.Unlock() - - if t.server != nil { - return nil, errListenerAlreadySet - } - conf = populateServerConfig(conf) - if err := t.init(false); err != nil { - return nil, err - } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false) + s, err := t.createServer(tlsConf, conf, false) if err != nil { return nil, err } - t.server = s return &Listener{baseServer: s}, nil } @@ -124,6 +139,14 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) // There can only be a single listener on any net.PacketConn. // Listen may only be called again after the current Listener was closed. func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { + s, err := t.createServer(tlsConf, conf, true) + if err != nil { + return nil, err + } + return &EarlyListener{baseServer: s}, nil +} + +func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) { if tlsConf == nil { return nil, errors.New("quic: tls.Config not set") } @@ -137,16 +160,26 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen if t.server != nil { return nil, errListenerAlreadySet } - conf = populateServerConfig(conf) + conf = populateConfig(conf) if err := t.init(false); err != nil { return nil, err } - s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true) - if err != nil { - return nil, err - } + s := newServer( + t.conn, + t.handlerMap, + t.connIDGenerator, + tlsConf, + conf, + t.Tracer, + t.closeServer, + *t.TokenGeneratorKey, + t.MaxTokenAge, + t.VerifySourceAddress, + t.DisableVersionNegotiationPackets, + allow0RTT, + ) t.server = s - return &EarlyListener{baseServer: s}, nil + return s, nil } // Dial dials a new connection to a remote host (not using 0-RTT). @@ -172,7 +205,6 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsCon onClose = func() { t.Close() } } tlsConf = tlsConf.Clone() - tlsConf.MinVersion = tls.VersionTLS13 setTLSConfigServerName(tlsConf, addr, host) return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) } @@ -198,6 +230,14 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.closeQueue = make(chan closePacket, 4) t.statelessResetQueue = make(chan receivedPacket, 4) + if t.TokenGeneratorKey == nil { + var key TokenGeneratorKey + if _, err := rand.Read(key[:]); err != nil { + t.initErr = err + return + } + t.TokenGeneratorKey = &key + } if t.ConnectionIDGenerator != nil { t.connIDGenerator = t.ConnectionIDGenerator @@ -223,7 +263,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) { if err := t.init(false); err != nil { return 0, err } - return t.conn.WritePacket(b, addr, nil) + return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported) } func (t *Transport) enqueueClosePacket(p closePacket) { @@ -241,14 +281,15 @@ func (t *Transport) runSendQueue() { case <-t.listening: return case p := <-t.closeQueue: - t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) + t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported) case p := <-t.statelessResetQueue: t.sendStatelessReset(p) } } } -// Close closes the underlying connection and waits until listen has returned. +// Close closes the underlying connection. +// If any listener was started, it will be closed as well. // It is invalid to start new listeners or connections after that. func (t *Transport) Close() error { t.close(errors.New("closing")) @@ -267,7 +308,6 @@ func (t *Transport) Close() error { } func (t *Transport) closeServer() { - t.handlerMap.CloseServer() t.mutex.Lock() t.server = nil if t.isSingleUse { @@ -295,7 +335,10 @@ func (t *Transport) close(e error) { t.handlerMap.Close(e) } if t.server != nil { - t.server.setCloseError(e) + t.server.close(e, false) + } + if t.Tracer != nil && t.Tracer.Close != nil { + t.Tracer.Close() } t.closed = true } @@ -346,20 +389,28 @@ func (t *Transport) handlePacket(p receivedPacket) { connID, err := wire.ParseConnectionID(p.data, t.connIDLen) if err != nil { t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) - if t.Tracer != nil { + if t.Tracer != nil && t.Tracer.DroppedPacket != nil { t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } p.buffer.MaybeRelease() return } - if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset { - return - } + // If there's a connection associated with the connection ID, pass the packet there. if handler, ok := t.handlerMap.Get(connID); ok { handler.handlePacket(p) return } + // RFC 9000 section 10.3.1 requires that the stateless reset detection logic is run for both + // packets that cannot be associated with any connections, and for packets that can't be decrypted. + // We deviate from the RFC and ignore the latter: If a packet's connection ID is associated with an + // existing connection, it is dropped there if if it can't be decrypted. + // Stateless resets use random connection IDs, and at reasonable connection ID lengths collisions are + // exceedingly rare. In the unlikely event that a stateless reset is misrouted to an existing connection, + // it is to be expected that the next stateless reset will be correctly detected. + if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset { + return + } if !wire.IsLongHeaderPacket(p.data[0]) { t.maybeSendStatelessReset(p) return @@ -409,7 +460,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) { rand.Read(data) data[0] = (data[0] & 0x7f) | 0x40 data = append(data, token[:]...) - if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) } } @@ -441,7 +492,7 @@ func (t *Transport) handleNonQUICPacket(p receivedPacket) { select { case t.nonQUICPackets <- p: default: - if t.Tracer != nil { + if t.Tracer != nil && t.Tracer.DroppedPacket != nil { t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } diff --git a/vendor/github.com/golang/mock/AUTHORS b/vendor/go.uber.org/mock/AUTHORS similarity index 100% rename from vendor/github.com/golang/mock/AUTHORS rename to vendor/go.uber.org/mock/AUTHORS diff --git a/vendor/github.com/golang/mock/LICENSE b/vendor/go.uber.org/mock/LICENSE similarity index 100% rename from vendor/github.com/golang/mock/LICENSE rename to vendor/go.uber.org/mock/LICENSE diff --git a/vendor/go.uber.org/mock/mockgen/generic_go118.go b/vendor/go.uber.org/mock/mockgen/generic_go118.go new file mode 100644 index 000000000..b7b449476 --- /dev/null +++ b/vendor/go.uber.org/mock/mockgen/generic_go118.go @@ -0,0 +1,130 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.18 +// +build go1.18 + +package main + +import ( + "errors" + "fmt" + "go/ast" + "go/token" + "strings" + + "go.uber.org/mock/mockgen/model" +) + +func getTypeSpecTypeParams(ts *ast.TypeSpec) []*ast.Field { + if ts == nil || ts.TypeParams == nil { + return nil + } + return ts.TypeParams.List +} + +func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]model.Type) (model.Type, error) { + switch v := typ.(type) { + case *ast.IndexExpr: + m, err := p.parseType(pkg, v.X, tps) + if err != nil { + return nil, err + } + nm, ok := m.(*model.NamedType) + if !ok { + return m, nil + } + t, err := p.parseType(pkg, v.Index, tps) + if err != nil { + return nil, err + } + nm.TypeParams = &model.TypeParametersType{TypeParameters: []model.Type{t}} + return m, nil + case *ast.IndexListExpr: + m, err := p.parseType(pkg, v.X, tps) + if err != nil { + return nil, err + } + nm, ok := m.(*model.NamedType) + if !ok { + return m, nil + } + var ts []model.Type + for _, expr := range v.Indices { + t, err := p.parseType(pkg, expr, tps) + if err != nil { + return nil, err + } + ts = append(ts, t) + } + nm.TypeParams = &model.TypeParametersType{TypeParameters: ts} + return m, nil + } + return nil, nil +} + +func getIdentTypeParams(decl any) string { + if decl == nil { + return "" + } + ts, ok := decl.(*ast.TypeSpec) + if !ok { + return "" + } + if ts.TypeParams == nil || len(ts.TypeParams.List) == 0 { + return "" + } + var sb strings.Builder + sb.WriteString("[") + for i, v := range ts.TypeParams.List { + if i != 0 { + sb.WriteString(", ") + } + sb.WriteString(v.Names[0].Name) + } + sb.WriteString("]") + return sb.String() +} + +func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, iface *model.Interface, pkg string, tps map[string]model.Type) ([]*model.Method, error) { + var indices []ast.Expr + var typ ast.Expr + switch v := field.Type.(type) { + case *ast.IndexExpr: + indices = []ast.Expr{v.Index} + typ = v.X + case *ast.IndexListExpr: + indices = v.Indices + typ = v.X + case *ast.UnaryExpr: + if v.Op == token.TILDE { + return nil, errConstraintInterface + } + return nil, fmt.Errorf("~T may only appear as constraint for %T", field.Type) + case *ast.BinaryExpr: + if v.Op == token.OR { + return nil, errConstraintInterface + } + return nil, fmt.Errorf("A|B may only appear as constraint for %T", field.Type) + default: + return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) + } + + nf := &ast.Field{ + Doc: field.Comment, + Names: field.Names, + Type: typ, + Tag: field.Tag, + Comment: field.Comment, + } + + it.embeddedInstTypeParams = indices + + return p.parseMethod(nf, it, iface, pkg, tps) +} + +var errConstraintInterface = errors.New("interface contains constraints") diff --git a/vendor/go.uber.org/mock/mockgen/generic_notgo118.go b/vendor/go.uber.org/mock/mockgen/generic_notgo118.go new file mode 100644 index 000000000..8a779c8b2 --- /dev/null +++ b/vendor/go.uber.org/mock/mockgen/generic_notgo118.go @@ -0,0 +1,41 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !go1.18 +// +build !go1.18 + +package main + +import ( + "fmt" + "go/ast" + + "go.uber.org/mock/mockgen/model" +) + +func getTypeSpecTypeParams(ts *ast.TypeSpec) []*ast.Field { + return nil +} + +func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]model.Type) (model.Type, error) { + return nil, nil +} + +func getIdentTypeParams(decl any) string { + return "" +} + +func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, iface *model.Interface, pkg string, tps map[string]model.Type) ([]*model.Method, error) { + return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) +} diff --git a/vendor/github.com/golang/mock/mockgen/mockgen.go b/vendor/go.uber.org/mock/mockgen/mockgen.go similarity index 62% rename from vendor/github.com/golang/mock/mockgen/mockgen.go rename to vendor/go.uber.org/mock/mockgen/mockgen.go index 50487070e..df7d85f09 100644 --- a/vendor/github.com/golang/mock/mockgen/mockgen.go +++ b/vendor/go.uber.org/mock/mockgen/mockgen.go @@ -21,29 +21,30 @@ package main import ( "bytes" "encoding/json" + "errors" "flag" "fmt" "go/token" "io" - "io/ioutil" "log" "os" "os/exec" "path" "path/filepath" + "runtime" "sort" "strconv" "strings" "unicode" - "github.com/golang/mock/mockgen/model" - "golang.org/x/mod/modfile" toolsimports "golang.org/x/tools/imports" + + "go.uber.org/mock/mockgen/model" ) const ( - gomockImportPath = "github.com/golang/mock/gomock" + gomockImportPath = "go.uber.org/mock/gomock" ) var ( @@ -53,13 +54,19 @@ var ( ) var ( - source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.") - destination = flag.String("destination", "", "Output file; defaults to stdout.") - mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.") - packageOut = flag.String("package", "", "Package of the generated code; defaults to the package of the input with a 'mock_' prefix.") - selfPackage = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. Setting this flag will then tell mockgen which import to exclude.") - writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.") - copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header") + source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.") + destination = flag.String("destination", "", "Output file; defaults to stdout.") + mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.") + packageOut = flag.String("package", "", "Package of the generated code; defaults to the package of the input with a 'mock_' prefix.") + selfPackage = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. Setting this flag will then tell mockgen which import to exclude.") + writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.") + writeSourceComment = flag.Bool("write_source_comment", true, "Writes original file (source mode) or interface names (reflect mode) comment if true.") + writeGenerateDirective = flag.Bool("write_generate_directive", false, "Add //go:generate directive to regenerate the mock") + copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header") + typed = flag.Bool("typed", false, "Generate Type-safe 'Return', 'Do', 'DoAndReturn' function") + imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.") + auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.") + excludeInterfaces = flag.String("exclude_interfaces", "", "Comma-separated names of interfaces to be excluded") debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") showVersion = flag.Bool("version", false, "Print version.") @@ -107,19 +114,6 @@ func main() { return } - dst := os.Stdout - if len(*destination) > 0 { - if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil { - log.Fatalf("Unable to create directory: %v", err) - } - f, err := os.Create(*destination) - if err != nil { - log.Fatalf("Failed opening destination file: %v", err) - } - defer f.Close() - dst = f - } - outputPackageName := *packageOut if outputPackageName == "" { // pkg.Name in reflect mode is the base name of the import path, @@ -161,7 +155,7 @@ func main() { g.mockNames = parseMockNames(*mockNames) } if *copyrightFile != "" { - header, err := ioutil.ReadFile(*copyrightFile) + header, err := os.ReadFile(*copyrightFile) if err != nil { log.Fatalf("Failed reading copyright file: %v", err) } @@ -171,7 +165,27 @@ func main() { if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil { log.Fatalf("Failed generating mock: %v", err) } - if _, err := dst.Write(g.Output()); err != nil { + output := g.Output() + dst := os.Stdout + if len(*destination) > 0 { + if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil { + log.Fatalf("Unable to create directory: %v", err) + } + existing, err := os.ReadFile(*destination) + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Fatalf("Failed reading pre-exiting destination file: %v", err) + } + if len(existing) == len(output) && bytes.Equal(existing, output) { + return + } + f, err := os.Create(*destination) + if err != nil { + log.Fatalf("Failed opening destination file: %v", err) + } + defer f.Close() + dst = f + } + if _, err := dst.Write(output); err != nil { log.Fatalf("Failed writing to destination: %v", err) } } @@ -188,6 +202,24 @@ func parseMockNames(names string) map[string]string { return mocksMap } +func parseExcludeInterfaces(names string) map[string]struct{} { + splitNames := strings.Split(names, ",") + namesSet := make(map[string]struct{}, len(splitNames)) + for _, name := range splitNames { + if name == "" { + continue + } + + namesSet[name] = struct{}{} + } + + if len(namesSet) == 0 { + return nil + } + + return namesSet +} + func usage() { _, _ = io.WriteString(os.Stderr, usageText) flag.PrintDefaults() @@ -222,7 +254,7 @@ type generator struct { packageMap map[string]string // map from import path to package name } -func (g *generator) p(format string, args ...interface{}) { +func (g *generator) p(format string, args ...any) { fmt.Fprintf(&g.buf, g.indent+format+"\n", args...) } @@ -274,12 +306,23 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac } g.p("// Code generated by MockGen. DO NOT EDIT.") - if g.filename != "" { - g.p("// Source: %v", g.filename) - } else { - g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces) + if *writeSourceComment { + if g.filename != "" { + g.p("// Source: %v", g.filename) + } else { + g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces) + } } - g.p("") + g.p("//") + g.p("// Generated by this command:") + g.p("//") + // only log the name of the executable, not the full path + name := filepath.Base(os.Args[0]) + if runtime.GOOS == "windows" { + name = strings.TrimSuffix(name, ".exe") + } + g.p("//\t%v", strings.Join(append([]string{name}, os.Args[1:]...), " ")) + g.p("//") // Get all required imports, and generate unique names for them all. im := pkg.Imports() @@ -305,6 +348,16 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac packagesName := createPackageMap(sortedPaths) + definedImports := make(map[string]string, len(im)) + if *imports != "" { + for _, kv := range strings.Split(*imports, ",") { + eq := strings.Index(kv, "=") + if k, v := kv[:eq], kv[eq+1:]; k != "." { + definedImports[v] = k + } + } + } + g.packageMap = make(map[string]string, len(im)) localNames := make(map[string]bool, len(im)) for _, pth := range sortedPaths { @@ -316,11 +369,16 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac // Local names for an imported package can usually be the basename of the import path. // A couple of situations don't permit that, such as duplicate local names // (e.g. importing "html/template" and "text/template"), or where the basename is - // a keyword (e.g. "foo/case"). + // a keyword (e.g. "foo/case") or when defining a name for that by using the -imports flag. // try base0, base1, ... pkgName := base + + if _, ok := definedImports[base]; ok { + pkgName = definedImports[base] + } + i := 0 - for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() { + for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() || pkgName == "any" { pkgName = base + strconv.Itoa(i) i++ } @@ -335,6 +393,10 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac } if *writePkgComment { + // Ensure there's an empty line before the package to follow the recommendations: + // https://github.com/golang/go/wiki/CodeReviewComments#package-comments + g.p("") + g.p("// Package %v is a generated GoMock package.", outputPkgName) } g.p("package %v", outputPkgName) @@ -353,6 +415,10 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac g.out() g.p(")") + if *writeGenerateDirective { + g.p("//go:generate %v", strings.Join(os.Args, " ")) + } + for _, intf := range pkg.Interfaces { if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil { return err @@ -371,32 +437,58 @@ func (g *generator) mockName(typeName string) string { return "Mock" + typeName } +// formattedTypeParams returns a long and short form of type param info used for +// printing. If analyzing a interface with type param [I any, O any] the result +// will be: +// "[I any, O any]", "[I, O]" +func (g *generator) formattedTypeParams(it *model.Interface, pkgOverride string) (string, string) { + if len(it.TypeParams) == 0 { + return "", "" + } + var long, short strings.Builder + long.WriteString("[") + short.WriteString("[") + for i, v := range it.TypeParams { + if i != 0 { + long.WriteString(", ") + short.WriteString(", ") + } + long.WriteString(v.Name) + short.WriteString(v.Name) + long.WriteString(fmt.Sprintf(" %s", v.Type.String(g.packageMap, pkgOverride))) + } + long.WriteString("]") + short.WriteString("]") + return long.String(), short.String() +} + func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error { mockType := g.mockName(intf.Name) + longTp, shortTp := g.formattedTypeParams(intf, outputPackagePath) g.p("") g.p("// %v is a mock of %v interface.", mockType, intf.Name) - g.p("type %v struct {", mockType) + g.p("type %v%v struct {", mockType, longTp) g.in() g.p("ctrl *gomock.Controller") - g.p("recorder *%vMockRecorder", mockType) + g.p("recorder *%vMockRecorder%v", mockType, shortTp) g.out() g.p("}") g.p("") g.p("// %vMockRecorder is the mock recorder for %v.", mockType, mockType) - g.p("type %vMockRecorder struct {", mockType) + g.p("type %vMockRecorder%v struct {", mockType, longTp) g.in() - g.p("mock *%v", mockType) + g.p("mock *%v%v", mockType, shortTp) g.out() g.p("}") g.p("") g.p("// New%v creates a new mock instance.", mockType) - g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType) + g.p("func New%v%v(ctrl *gomock.Controller) *%v%v {", mockType, longTp, mockType, shortTp) g.in() - g.p("mock := &%v{ctrl: ctrl}", mockType) - g.p("mock.recorder = &%vMockRecorder{mock}", mockType) + g.p("mock := &%v%v{ctrl: ctrl}", mockType, shortTp) + g.p("mock.recorder = &%vMockRecorder%v{mock}", mockType, shortTp) g.p("return mock") g.out() g.p("}") @@ -404,13 +496,13 @@ func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePa // XXX: possible name collision here if someone has EXPECT in their interface. g.p("// EXPECT returns an object that allows the caller to indicate expected use.") - g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType) + g.p("func (m *%v%v) EXPECT() *%vMockRecorder%v {", mockType, shortTp, mockType, shortTp) g.in() g.p("return m.recorder") g.out() g.p("}") - g.GenerateMockMethods(mockType, intf, outputPackagePath) + g.GenerateMockMethods(mockType, intf, outputPackagePath, longTp, shortTp, *typed) return nil } @@ -421,13 +513,17 @@ func (b byMethodName) Len() int { return len(b) } func (b byMethodName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name } -func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) { +func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride, longTp, shortTp string, typed bool) { sort.Sort(byMethodName(intf.Methods)) for _, m := range intf.Methods { g.p("") - _ = g.GenerateMockMethod(mockType, m, pkgOverride) + _ = g.GenerateMockMethod(mockType, m, pkgOverride, shortTp) g.p("") - _ = g.GenerateMockRecorderMethod(mockType, m) + _ = g.GenerateMockRecorderMethod(intf, m, shortTp, typed) + if typed { + g.p("") + _ = g.GenerateMockReturnCallMethod(intf, m, pkgOverride, longTp, shortTp) + } } } @@ -446,9 +542,9 @@ func makeArgString(argNames, argTypes []string) string { // GenerateMockMethod generates a mock method implementation. // If non-empty, pkgOverride is the package in which unqualified types reside. -func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error { - argNames := g.getArgNames(m) - argTypes := g.getArgTypes(m, pkgOverride) +func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride, shortTp string) error { + argNames := g.getArgNames(m, true /* in */) + argTypes := g.getArgTypes(m, pkgOverride, true /* in */) argString := makeArgString(argNames, argTypes) rets := make([]string, len(m.Out)) @@ -467,7 +563,7 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver idRecv := ia.allocateIdentifier("m") g.p("// %v mocks base method.", m.Name) - g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString) + g.p("func (%v *%v%v) %v(%v)%v {", idRecv, mockType, shortTp, m.Name, argString, retString) g.in() g.p("%s.ctrl.T.Helper()", idRecv) @@ -477,11 +573,11 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver callArgs = ", " + strings.Join(argNames, ", ") } } else { - // Non-trivial. The generated code must build a []interface{}, + // Non-trivial. The generated code must build a []any, // but the variadic argument may be any type. idVarArgs := ia.allocateIdentifier("varargs") idVArg := ia.allocateIdentifier("a") - g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", ")) + g.p("%s := []any{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", ")) g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1]) g.in() g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg) @@ -511,8 +607,9 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver return nil } -func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error { - argNames := g.getArgNames(m) +func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, m *model.Method, shortTp string, typed bool) error { + mockType := g.mockName(intf.Name) + argNames := g.getArgNames(m, true) var argString string if m.Variadic == nil { @@ -521,21 +618,26 @@ func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) argString = strings.Join(argNames[:len(argNames)-1], ", ") } if argString != "" { - argString += " interface{}" + argString += " any" } if m.Variadic != nil { if argString != "" { argString += ", " } - argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1]) + argString += fmt.Sprintf("%s ...any", argNames[len(argNames)-1]) } ia := newIdentifierAllocator(argNames) idRecv := ia.allocateIdentifier("mr") g.p("// %v indicates an expected call of %v.", m.Name, m.Name) - g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString) + if typed { + g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, mockType, m.Name, shortTp) + } else { + g.p("func (%s *%vMockRecorder%v) %v(%v) *gomock.Call {", idRecv, mockType, shortTp, m.Name, argString) + } + g.in() g.p("%s.mock.ctrl.T.Helper()", idRecv) @@ -551,42 +653,122 @@ func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) } else { // Hard: create a temporary slice. idVarArgs := ia.allocateIdentifier("varargs") - g.p("%s := append([]interface{}{%s}, %s...)", + g.p("%s := append([]any{%s}, %s...)", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", "), argNames[len(argNames)-1]) callArgs = ", " + idVarArgs + "..." } } - g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs) + if typed { + g.p(`call := %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs) + g.p(`return &%s%sCall%s{Call: call}`, mockType, m.Name, shortTp) + } else { + g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs) + } g.out() g.p("}") return nil } -func (g *generator) getArgNames(m *model.Method) []string { - argNames := make([]string, len(m.In)) - for i, p := range m.In { +func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model.Method, pkgOverride, longTp, shortTp string) error { + mockType := g.mockName(intf.Name) + argNames := g.getArgNames(m, true /* in */) + retNames := g.getArgNames(m, false /* out */) + argTypes := g.getArgTypes(m, pkgOverride, true /* in */) + retTypes := g.getArgTypes(m, pkgOverride, false /* out */) + argString := strings.Join(argTypes, ", ") + + rets := make([]string, len(m.Out)) + for i, p := range m.Out { + rets[i] = p.Type.String(g.packageMap, pkgOverride) + } + + var retString string + switch { + case len(rets) == 1: + retString = " " + rets[0] + case len(rets) > 1: + retString = " (" + strings.Join(rets, ", ") + ")" + } + + ia := newIdentifierAllocator(argNames) + idRecv := ia.allocateIdentifier("c") + + recvStructName := mockType + m.Name + + g.p("// %s%sCall wrap *gomock.Call", mockType, m.Name) + g.p("type %s%sCall%s struct{", mockType, m.Name, longTp) + g.in() + g.p("*gomock.Call") + g.out() + g.p("}") + + g.p("// Return rewrite *gomock.Call.Return") + g.p("func (%s *%sCall%s) Return(%v) *%sCall%s {", idRecv, recvStructName, shortTp, makeArgString(retNames, retTypes), recvStructName, shortTp) + g.in() + var retArgs string + if len(retNames) > 0 { + retArgs = strings.Join(retNames, ", ") + } + g.p(`%s.Call = %v.Call.Return(%v)`, idRecv, idRecv, retArgs) + g.p("return %s", idRecv) + g.out() + g.p("}") + + g.p("// Do rewrite *gomock.Call.Do") + g.p("func (%s *%sCall%s) Do(f func(%v)%v) *%sCall%s {", idRecv, recvStructName, shortTp, argString, retString, recvStructName, shortTp) + g.in() + g.p(`%s.Call = %v.Call.Do(f)`, idRecv, idRecv) + g.p("return %s", idRecv) + g.out() + g.p("}") + + g.p("// DoAndReturn rewrite *gomock.Call.DoAndReturn") + g.p("func (%s *%sCall%s) DoAndReturn(f func(%v)%v) *%sCall%s {", idRecv, recvStructName, shortTp, argString, retString, recvStructName, shortTp) + g.in() + g.p(`%s.Call = %v.Call.DoAndReturn(f)`, idRecv, idRecv) + g.p("return %s", idRecv) + g.out() + g.p("}") + return nil +} + +func (g *generator) getArgNames(m *model.Method, in bool) []string { + var params []*model.Parameter + if in { + params = m.In + } else { + params = m.Out + } + argNames := make([]string, len(params)) + for i, p := range params { name := p.Name if name == "" || name == "_" { name = fmt.Sprintf("arg%d", i) } argNames[i] = name } - if m.Variadic != nil { + if m.Variadic != nil && in { name := m.Variadic.Name if name == "" { - name = fmt.Sprintf("arg%d", len(m.In)) + name = fmt.Sprintf("arg%d", len(params)) } argNames = append(argNames, name) } return argNames } -func (g *generator) getArgTypes(m *model.Method, pkgOverride string) []string { - argTypes := make([]string, len(m.In)) - for i, p := range m.In { +func (g *generator) getArgTypes(m *model.Method, pkgOverride string, in bool) []string { + var params []*model.Parameter + if in { + params = m.In + } else { + params = m.Out + } + argTypes := make([]string, len(params)) + for i, p := range params { argTypes[i] = p.Type.String(g.packageMap, pkgOverride) } if m.Variadic != nil { @@ -670,7 +852,7 @@ func parsePackageImport(srcDir string) (string, error) { if moduleMode != "off" { currentDir := srcDir for { - dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod")) + dat, err := os.ReadFile(filepath.Join(currentDir, "go.mod")) if os.IsNotExist(err) { if currentDir == filepath.Dir(currentDir) { // at the root diff --git a/vendor/github.com/golang/mock/mockgen/model/model.go b/vendor/go.uber.org/mock/mockgen/model/model.go similarity index 87% rename from vendor/github.com/golang/mock/mockgen/model/model.go rename to vendor/go.uber.org/mock/mockgen/model/model.go index 2c6a62ceb..853dbf2d6 100644 --- a/vendor/github.com/golang/mock/mockgen/model/model.go +++ b/vendor/go.uber.org/mock/mockgen/model/model.go @@ -24,7 +24,7 @@ import ( ) // pkgPath is the importable path for package model -const pkgPath = "github.com/golang/mock/mockgen/model" +const pkgPath = "go.uber.org/mock/mockgen/model" // Package is a Go package. It may be a subset. type Package struct { @@ -47,14 +47,18 @@ func (pkg *Package) Imports() map[string]bool { im := make(map[string]bool) for _, intf := range pkg.Interfaces { intf.addImports(im) + for _, tp := range intf.TypeParams { + tp.Type.addImports(im) + } } return im } // Interface is a Go interface. type Interface struct { - Name string - Methods []*Method + Name string + Methods []*Method + TypeParams []*Parameter } // Print writes the interface name and its methods. @@ -143,12 +147,14 @@ type Type interface { } func init() { - gob.Register(&ArrayType{}) - gob.Register(&ChanType{}) - gob.Register(&FuncType{}) - gob.Register(&MapType{}) - gob.Register(&NamedType{}) - gob.Register(&PointerType{}) + // Call gob.RegisterName with pkgPath as prefix to avoid conflicting with + // github.com/golang/mock/mockgen/model 's registration. + gob.RegisterName(pkgPath+".ArrayType", &ArrayType{}) + gob.RegisterName(pkgPath+".ChanType", &ChanType{}) + gob.RegisterName(pkgPath+".FuncType", &FuncType{}) + gob.RegisterName(pkgPath+".MapType", &MapType{}) + gob.RegisterName(pkgPath+".NamedType", &NamedType{}) + gob.RegisterName(pkgPath+".PointerType", &PointerType{}) // Call gob.RegisterName to make sure it has the consistent name registered // for both gob decoder and encoder. @@ -156,7 +162,7 @@ func init() { // For a non-pointer type, gob.Register will try to get package full path by // calling rt.PkgPath() for a name to register. If your project has vendor // directory, it is possible that PkgPath will get a path like this: - // ../../../vendor/github.com/golang/mock/mockgen/model + // ../../../vendor/go.uber.org/mock/mockgen/model gob.RegisterName(pkgPath+".PredeclaredType", PredeclaredType("")) } @@ -259,26 +265,28 @@ func (mt *MapType) addImports(im map[string]bool) { // NamedType is an exported type in a package. type NamedType struct { - Package string // may be empty - Type string + Package string // may be empty + Type string + TypeParams *TypeParametersType } func (nt *NamedType) String(pm map[string]string, pkgOverride string) string { if pkgOverride == nt.Package { - return nt.Type + return nt.Type + nt.TypeParams.String(pm, pkgOverride) } prefix := pm[nt.Package] if prefix != "" { - return prefix + "." + nt.Type + return prefix + "." + nt.Type + nt.TypeParams.String(pm, pkgOverride) } - return nt.Type + return nt.Type + nt.TypeParams.String(pm, pkgOverride) } func (nt *NamedType) addImports(im map[string]bool) { if nt.Package != "" { im[nt.Package] = true } + nt.TypeParams.addImports(im) } // PointerType is a pointer to another type. @@ -297,6 +305,36 @@ type PredeclaredType string func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) } func (pt PredeclaredType) addImports(map[string]bool) {} +// TypeParametersType contains type parameters for a NamedType. +type TypeParametersType struct { + TypeParameters []Type +} + +func (tp *TypeParametersType) String(pm map[string]string, pkgOverride string) string { + if tp == nil || len(tp.TypeParameters) == 0 { + return "" + } + var sb strings.Builder + sb.WriteString("[") + for i, v := range tp.TypeParameters { + if i != 0 { + sb.WriteString(", ") + } + sb.WriteString(v.String(pm, pkgOverride)) + } + sb.WriteString("]") + return sb.String() +} + +func (tp *TypeParametersType) addImports(im map[string]bool) { + if tp == nil { + return + } + for _, v := range tp.TypeParameters { + v.addImports(im) + } +} + // The following code is intended to be called by the program generated by ../reflect.go. // InterfaceFromInterfaceType returns a pointer to an interface for the @@ -431,7 +469,7 @@ func typeFromType(t reflect.Type) (Type, error) { case reflect.Interface: // Two special interfaces. if t.NumMethod() == 0 { - return PredeclaredType("interface{}"), nil + return PredeclaredType("any"), nil } if t == errorType { return PredeclaredType("error"), nil diff --git a/vendor/github.com/golang/mock/mockgen/parse.go b/vendor/go.uber.org/mock/mockgen/parse.go similarity index 63% rename from vendor/github.com/golang/mock/mockgen/parse.go rename to vendor/go.uber.org/mock/mockgen/parse.go index bf6902cd5..f43321c33 100644 --- a/vendor/github.com/golang/mock/mockgen/parse.go +++ b/vendor/go.uber.org/mock/mockgen/parse.go @@ -18,7 +18,6 @@ package main import ( "errors" - "flag" "fmt" "go/ast" "go/build" @@ -26,19 +25,14 @@ import ( "go/parser" "go/token" "go/types" - "io/ioutil" "log" + "os" "path" "path/filepath" "strconv" "strings" - "github.com/golang/mock/mockgen/model" -) - -var ( - imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.") - auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.") + "go.uber.org/mock/mockgen/model" ) // sourceMode generates mocks via source file. @@ -62,8 +56,8 @@ func sourceMode(source string) (*model.Package, error) { p := &fileParser{ fileSet: fs, imports: make(map[string]importedPackage), - importedInterfaces: make(map[string]map[string]*ast.InterfaceType), - auxInterfaces: make(map[string]map[string]*ast.InterfaceType), + importedInterfaces: newInterfaceCache(), + auxInterfaces: newInterfaceCache(), srcDir: srcDir, } @@ -81,6 +75,10 @@ func sourceMode(source string) (*model.Package, error) { } } + if *excludeInterfaces != "" { + p.excludeNamesSet = parseExcludeInterfaces(*excludeInterfaces) + } + // Handle -aux_files. if err := p.parseAuxFiles(*auxFiles); err != nil { return nil, err @@ -127,21 +125,55 @@ func (d duplicateImport) Error() string { func (d duplicateImport) Path() string { log.Fatal(d.Error()); return "" } func (d duplicateImport) Parser() *fileParser { log.Fatal(d.Error()); return nil } -type fileParser struct { - fileSet *token.FileSet - imports map[string]importedPackage // package name => imported package - importedInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface +type interfaceCache struct { + m map[string]map[string]*namedInterface +} + +func newInterfaceCache() *interfaceCache { + return &interfaceCache{ + m: make(map[string]map[string]*namedInterface), + } +} + +func (i *interfaceCache) Set(pkg, name string, it *namedInterface) { + if _, ok := i.m[pkg]; !ok { + i.m[pkg] = make(map[string]*namedInterface) + } + i.m[pkg][name] = it +} + +func (i *interfaceCache) Get(pkg, name string) *namedInterface { + if _, ok := i.m[pkg]; !ok { + return nil + } + return i.m[pkg][name] +} - auxFiles []*ast.File - auxInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface +func (i *interfaceCache) GetASTIface(pkg, name string) *ast.InterfaceType { + if _, ok := i.m[pkg]; !ok { + return nil + } + it, ok := i.m[pkg][name] + if !ok { + return nil + } + return it.it +} - srcDir string +type fileParser struct { + fileSet *token.FileSet + imports map[string]importedPackage // package name => imported package + importedInterfaces *interfaceCache + auxFiles []*ast.File + auxInterfaces *interfaceCache + srcDir string + excludeNamesSet map[string]struct{} } -func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error { +func (p *fileParser) errorf(pos token.Pos, format string, args ...any) error { ps := p.fileSet.Position(pos) format = "%s:%d:%d: " + format - args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...) + args = append([]any{ps.Filename, ps.Line, ps.Column}, args...) return fmt.Errorf(format, args...) } @@ -168,11 +200,8 @@ func (p *fileParser) parseAuxFiles(auxFiles string) error { } func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) { - if _, ok := p.auxInterfaces[pkg]; !ok { - p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType) - } for ni := range iterInterfaces(file) { - p.auxInterfaces[pkg][ni.name.Name] = ni.it + p.auxInterfaces.Set(pkg, ni.name.Name, ni) } } @@ -199,7 +228,13 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag var is []*model.Interface for ni := range iterInterfaces(file) { - i, err := p.parseInterface(ni.name.String(), importPath, ni.it) + if _, ok := p.excludeNamesSet[ni.name.String()]; ok { + continue + } + i, err := p.parseInterface(ni.name.String(), importPath, ni) + if errors.Is(err, errConstraintInterface) { + continue + } if err != nil { return nil, err } @@ -219,8 +254,8 @@ func (p *fileParser) parsePackage(path string) (*fileParser, error) { newP := &fileParser{ fileSet: token.NewFileSet(), imports: make(map[string]importedPackage), - importedInterfaces: make(map[string]map[string]*ast.InterfaceType), - auxInterfaces: make(map[string]map[string]*ast.InterfaceType), + importedInterfaces: newInterfaceCache(), + auxInterfaces: newInterfaceCache(), srcDir: p.srcDir, } @@ -233,11 +268,8 @@ func (p *fileParser) parsePackage(path string) (*fileParser, error) { for _, pkg := range pkgs { file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates) - if _, ok := newP.importedInterfaces[path]; !ok { - newP.importedInterfaces[path] = make(map[string]*ast.InterfaceType) - } for ni := range iterInterfaces(file) { - newP.importedInterfaces[path][ni.name.Name] = ni.it + newP.importedInterfaces.Set(path, ni.name.Name, ni) } imports, _ := importsOfFile(file) for pkgName, pkgI := range imports { @@ -247,9 +279,77 @@ func (p *fileParser) parsePackage(path string) (*fileParser, error) { return newP, nil } -func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) { +func (p *fileParser) constructInstParams(pkg string, params []*ast.Field, instParams []model.Type, embeddedInstParams []ast.Expr, tps map[string]model.Type) ([]model.Type, error) { + pm := make(map[string]int) + var i int + for _, v := range params { + for _, n := range v.Names { + pm[n.Name] = i + instParams = append(instParams, model.PredeclaredType(n.Name)) + i++ + } + } + + var runtimeInstParams []model.Type + for _, instParam := range embeddedInstParams { + switch t := instParam.(type) { + case *ast.Ident: + if idx, ok := pm[t.Name]; ok { + runtimeInstParams = append(runtimeInstParams, instParams[idx]) + continue + } + } + modelType, err := p.parseType(pkg, instParam, tps) + if err != nil { + return nil, err + } + runtimeInstParams = append(runtimeInstParams, modelType) + } + + return runtimeInstParams, nil +} + +func (p *fileParser) constructTps(it *namedInterface) (tps map[string]model.Type) { + tps = make(map[string]model.Type) + n := 0 + for _, tp := range it.typeParams { + for _, tm := range tp.Names { + tps[tm.Name] = nil + if len(it.instTypes) != 0 { + tps[tm.Name] = it.instTypes[n] + n++ + } + } + } + return tps +} + +// parseInterface loads interface specified by pkg and name, parses it and returns +// a new model with the parsed. +func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*model.Interface, error) { iface := &model.Interface{Name: name} - for _, field := range it.Methods.List { + tps := p.constructTps(it) + tp, err := p.parseFieldList(pkg, it.typeParams, tps) + if err != nil { + return nil, fmt.Errorf("unable to parse interface type parameters: %v", name) + } + + iface.TypeParams = tp + for _, field := range it.it.Methods.List { + var methods []*model.Method + if methods, err = p.parseMethod(field, it, iface, pkg, tps); err != nil { + return nil, err + } + for _, m := range methods { + iface.AddMethod(m) + } + } + return iface, nil +} + +func (p *fileParser) parseMethod(field *ast.Field, it *namedInterface, iface *model.Interface, pkg string, tps map[string]model.Type) ([]*model.Method, error) { + // {} for git diff + { switch v := field.Type.(type) { case *ast.FuncType: if nn := len(field.Names); nn != 1 { @@ -259,37 +359,55 @@ func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*m Name: field.Names[0].String(), } var err error - m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v) + m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v, tps) if err != nil { return nil, err } - iface.AddMethod(m) + return []*model.Method{m}, nil case *ast.Ident: // Embedded interface in this package. - embeddedIfaceType := p.auxInterfaces[pkg][v.String()] + embeddedIfaceType := p.auxInterfaces.Get(pkg, v.String()) if embeddedIfaceType == nil { - embeddedIfaceType = p.importedInterfaces[pkg][v.String()] + embeddedIfaceType = p.importedInterfaces.Get(pkg, v.String()) } var embeddedIface *model.Interface if embeddedIfaceType != nil { var err error + embeddedIfaceType.instTypes, err = p.constructInstParams(pkg, it.typeParams, it.instTypes, it.embeddedInstTypeParams, tps) + if err != nil { + return nil, err + } embeddedIface, err = p.parseInterface(v.String(), pkg, embeddedIfaceType) if err != nil { return nil, err } + } else { // This is built-in error interface. if v.String() == model.ErrorInterface.Name { embeddedIface = &model.ErrorInterface } else { - return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String()) + ip, err := p.parsePackage(pkg) + if err != nil { + return nil, p.errorf(v.Pos(), "could not parse package %s: %v", pkg, err) + } + + if embeddedIfaceType = ip.importedInterfaces.Get(pkg, v.String()); embeddedIfaceType == nil { + return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", pkg, v.String()) + } + + embeddedIfaceType.instTypes, err = p.constructInstParams(pkg, it.typeParams, it.instTypes, it.embeddedInstTypeParams, tps) + if err != nil { + return nil, err + } + embeddedIface, err = ip.parseInterface(v.String(), pkg, embeddedIfaceType) + if err != nil { + return nil, err + } } } - // Copy the methods. - for _, m := range embeddedIface.Methods { - iface.AddMethod(m) - } + return embeddedIface.Methods, nil case *ast.SelectorExpr: // Embedded interface in another package. filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() @@ -300,8 +418,12 @@ func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*m var embeddedIface *model.Interface var err error - embeddedIfaceType := p.auxInterfaces[filePkg][sel] + embeddedIfaceType := p.auxInterfaces.Get(filePkg, sel) if embeddedIfaceType != nil { + embeddedIfaceType.instTypes, err = p.constructInstParams(pkg, it.typeParams, it.instTypes, it.embeddedInstTypeParams, tps) + if err != nil { + return nil, err + } embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType) if err != nil { return nil, err @@ -320,46 +442,47 @@ func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*m parser: parser, } } - if embeddedIfaceType = parser.importedInterfaces[path][sel]; embeddedIfaceType == nil { + if embeddedIfaceType = parser.importedInterfaces.Get(path, sel); embeddedIfaceType == nil { return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel) } + + embeddedIfaceType.instTypes, err = p.constructInstParams(pkg, it.typeParams, it.instTypes, it.embeddedInstTypeParams, tps) + if err != nil { + return nil, err + } embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType) if err != nil { return nil, err } } - // Copy the methods. // TODO: apply shadowing rules. - for _, m := range embeddedIface.Methods { - iface.AddMethod(m) - } + return embeddedIface.Methods, nil default: - return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) + return p.parseGenericMethod(field, it, iface, pkg, tps) } } - return iface, nil } -func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (inParam []*model.Parameter, variadic *model.Parameter, outParam []*model.Parameter, err error) { +func (p *fileParser) parseFunc(pkg string, f *ast.FuncType, tps map[string]model.Type) (inParam []*model.Parameter, variadic *model.Parameter, outParam []*model.Parameter, err error) { if f.Params != nil { regParams := f.Params.List if isVariadic(f) { n := len(regParams) varParams := regParams[n-1:] regParams = regParams[:n-1] - vp, err := p.parseFieldList(pkg, varParams) + vp, err := p.parseFieldList(pkg, varParams, tps) if err != nil { return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err) } variadic = vp[0] } - inParam, err = p.parseFieldList(pkg, regParams) + inParam, err = p.parseFieldList(pkg, regParams, tps) if err != nil { return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err) } } if f.Results != nil { - outParam, err = p.parseFieldList(pkg, f.Results.List) + outParam, err = p.parseFieldList(pkg, f.Results.List, tps) if err != nil { return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err) } @@ -367,7 +490,7 @@ func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (inParam []*model.Pa return } -func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) { +func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field, tps map[string]model.Type) ([]*model.Parameter, error) { nf := 0 for _, f := range fields { nn := len(f.Names) @@ -382,7 +505,7 @@ func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.P ps := make([]*model.Parameter, nf) i := 0 // destination index for _, f := range fields { - t, err := p.parseType(pkg, f.Type) + t, err := p.parseType(pkg, f.Type, tps) if err != nil { return nil, err } @@ -401,44 +524,27 @@ func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.P return ps, nil } -func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { +func (p *fileParser) parseType(pkg string, typ ast.Expr, tps map[string]model.Type) (model.Type, error) { switch v := typ.(type) { case *ast.ArrayType: ln := -1 if v.Len != nil { - var value string - switch val := v.Len.(type) { - case (*ast.BasicLit): - value = val.Value - case (*ast.Ident): - // when the length is a const defined locally - value = val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value - case (*ast.SelectorExpr): - // when the length is a const defined in an external package - usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X)) - if err != nil { - return nil, p.errorf(v.Len.Pos(), "unknown package in array length: %v", err) - } - ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name) - if err != nil { - return nil, p.errorf(v.Len.Pos(), "unknown constant in array length: %v", err) - } - value = ev.Value.String() + value, err := p.parseArrayLength(v.Len) + if err != nil { + return nil, err } - - x, err := strconv.Atoi(value) + ln, err = strconv.Atoi(value) if err != nil { return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err) } - ln = x } - t, err := p.parseType(pkg, v.Elt) + t, err := p.parseType(pkg, v.Elt, tps) if err != nil { return nil, err } return &model.ArrayType{Len: ln, Type: t}, nil case *ast.ChanType: - t, err := p.parseType(pkg, v.Value) + t, err := p.parseType(pkg, v.Value, tps) if err != nil { return nil, err } @@ -452,15 +558,16 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { return &model.ChanType{Dir: dir, Type: t}, nil case *ast.Ellipsis: // assume we're parsing a variadic argument - return p.parseType(pkg, v.Elt) + return p.parseType(pkg, v.Elt, tps) case *ast.FuncType: - in, variadic, out, err := p.parseFunc(pkg, v) + in, variadic, out, err := p.parseFunc(pkg, v, tps) if err != nil { return nil, err } return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil case *ast.Ident: - if v.IsExported() { + it, ok := tps[v.Name] + if v.IsExported() && !ok { // `pkg` may be an aliased imported pkg // if so, patch the import w/ the fully qualified import maybeImportedPkg, ok := p.imports[pkg] @@ -470,20 +577,22 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { // assume type in this package return &model.NamedType{Package: pkg, Type: v.Name}, nil } - + if ok && it != nil { + return it, nil + } // assume predeclared type return model.PredeclaredType(v.Name), nil case *ast.InterfaceType: if v.Methods != nil && len(v.Methods.List) > 0 { return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types") } - return model.PredeclaredType("interface{}"), nil + return model.PredeclaredType("any"), nil case *ast.MapType: - key, err := p.parseType(pkg, v.Key) + key, err := p.parseType(pkg, v.Key, tps) if err != nil { return nil, err } - value, err := p.parseType(pkg, v.Value) + value, err := p.parseType(pkg, v.Value, tps) if err != nil { return nil, err } @@ -496,7 +605,7 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { } return &model.NamedType{Package: pkg.Path(), Type: v.Sel.String()}, nil case *ast.StarExpr: - t, err := p.parseType(pkg, v.X) + t, err := p.parseType(pkg, v.X, tps) if err != nil { return nil, err } @@ -507,12 +616,61 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { } return model.PredeclaredType("struct{}"), nil case *ast.ParenExpr: - return p.parseType(pkg, v.X) + return p.parseType(pkg, v.X, tps) + default: + mt, err := p.parseGenericType(pkg, typ, tps) + if err != nil { + return nil, err + } + if mt == nil { + break + } + return mt, nil } return nil, fmt.Errorf("don't know how to parse type %T", typ) } +func (p *fileParser) parseArrayLength(expr ast.Expr) (string, error) { + switch val := expr.(type) { + case (*ast.BasicLit): + return val.Value, nil + case (*ast.Ident): + // when the length is a const defined locally + return val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value, nil + case (*ast.SelectorExpr): + // when the length is a const defined in an external package + usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X)) + if err != nil { + return "", p.errorf(expr.Pos(), "unknown package in array length: %v", err) + } + ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name) + if err != nil { + return "", p.errorf(expr.Pos(), "unknown constant in array length: %v", err) + } + return ev.Value.String(), nil + case (*ast.ParenExpr): + return p.parseArrayLength(val.X) + case (*ast.BinaryExpr): + x, err := p.parseArrayLength(val.X) + if err != nil { + return "", err + } + y, err := p.parseArrayLength(val.Y) + if err != nil { + return "", err + } + biExpr := fmt.Sprintf("%s%v%s", x, val.Op, y) + tv, err := types.Eval(token.NewFileSet(), nil, token.NoPos, biExpr) + if err != nil { + return "", p.errorf(expr.Pos(), "invalid expression in array length: %v", err) + } + return tv.Value.String(), nil + default: + return "", p.errorf(expr.Pos(), "invalid expression in array length: %v", val) + } +} + // importsOfFile returns a map of package name to import path // of the imports in file. func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) { @@ -575,13 +733,16 @@ func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, do } type namedInterface struct { - name *ast.Ident - it *ast.InterfaceType + name *ast.Ident + it *ast.InterfaceType + typeParams []*ast.Field + embeddedInstTypeParams []ast.Expr + instTypes []model.Type } // Create an iterator over all interfaces in file. -func iterInterfaces(file *ast.File) <-chan namedInterface { - ch := make(chan namedInterface) +func iterInterfaces(file *ast.File) <-chan *namedInterface { + ch := make(chan *namedInterface) go func() { for _, decl := range file.Decls { gd, ok := decl.(*ast.GenDecl) @@ -598,7 +759,7 @@ func iterInterfaces(file *ast.File) <-chan namedInterface { continue } - ch <- namedInterface{ts.Name, it} + ch <- &namedInterface{name: ts.Name, it: it, typeParams: getTypeSpecTypeParams(ts)} } } close(ch) @@ -618,7 +779,7 @@ func isVariadic(f *ast.FuncType) bool { // packageNameOfDir get package import path via dir func packageNameOfDir(srcDir string) (string, error) { - files, err := ioutil.ReadDir(srcDir) + files, err := os.ReadDir(srcDir) if err != nil { log.Fatal(err) } diff --git a/vendor/github.com/golang/mock/mockgen/reflect.go b/vendor/go.uber.org/mock/mockgen/reflect.go similarity index 93% rename from vendor/github.com/golang/mock/mockgen/reflect.go rename to vendor/go.uber.org/mock/mockgen/reflect.go index e24efce0b..ca80ebbb6 100644 --- a/vendor/github.com/golang/mock/mockgen/reflect.go +++ b/vendor/go.uber.org/mock/mockgen/reflect.go @@ -23,7 +23,6 @@ import ( "fmt" "go/build" "io" - "io/ioutil" "log" "os" "os/exec" @@ -32,7 +31,7 @@ import ( "strings" "text/template" - "github.com/golang/mock/mockgen/model" + "go.uber.org/mock/mockgen/model" ) var ( @@ -92,7 +91,7 @@ func writeProgram(importPath string, symbols []string) ([]byte, error) { // run the given program and parse the output as a model.Package. func run(program string) (*model.Package, error) { - f, err := ioutil.TempFile("", "") + f, err := os.CreateTemp("", "") if err != nil { return nil, err } @@ -133,7 +132,7 @@ func run(program string) (*model.Package, error) { // parses the output as a model.Package. func runInDir(program []byte, dir string) (*model.Package, error) { // We use TempDir instead of TempFile so we can control the filename. - tmpDir, err := ioutil.TempDir(dir, "gomock_reflect_") + tmpDir, err := os.MkdirTemp(dir, "gomock_reflect_") if err != nil { return nil, err } @@ -149,7 +148,7 @@ func runInDir(program []byte, dir string) (*model.Package, error) { progBinary += ".exe" } - if err := ioutil.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil { + if err := os.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil { return nil, err } @@ -169,8 +168,8 @@ func runInDir(program []byte, dir string) (*model.Package, error) { if err := cmd.Run(); err != nil { sErr := buf.String() if strings.Contains(sErr, `cannot find package "."`) && - strings.Contains(sErr, "github.com/golang/mock/mockgen/model") { - fmt.Fprint(os.Stderr, "Please reference the steps in the README to fix this error:\n\thttps://github.com/golang/mock#reflect-vendoring-error.") + strings.Contains(sErr, "go.uber.org/mock/mockgen/model") { + fmt.Fprint(os.Stderr, "Please reference the steps in the README to fix this error:\n\thttps://go.uber.org/mock#reflect-vendoring-error.\n") return nil, err } return nil, err @@ -188,6 +187,7 @@ type reflectData struct { // gob encoding of a model.Package to standard output. // JSON doesn't work because of the model.Type interface. var reflectProgram = template.Must(template.New("program").Parse(` +// Code generated by MockGen. DO NOT EDIT. package main import ( @@ -198,7 +198,7 @@ import ( "path" "reflect" - "github.com/golang/mock/mockgen/model" + "go.uber.org/mock/mockgen/model" pkg_ {{printf "%q" .ImportPath}} ) diff --git a/vendor/github.com/golang/mock/mockgen/version.1.12.go b/vendor/go.uber.org/mock/mockgen/version.go similarity index 94% rename from vendor/github.com/golang/mock/mockgen/version.1.12.go rename to vendor/go.uber.org/mock/mockgen/version.go index ad121ae63..6db160ac2 100644 --- a/vendor/github.com/golang/mock/mockgen/version.1.12.go +++ b/vendor/go.uber.org/mock/mockgen/version.go @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,9 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// - -// +build go1.12 package main @@ -31,5 +28,4 @@ func printModuleVersion() { "GO111MODULE=on when running 'go get' in order to use specific " + "version of the binary.") } - } diff --git a/vendor/golang.org/x/crypto/cryptobyte/asn1.go b/vendor/golang.org/x/crypto/cryptobyte/asn1.go deleted file mode 100644 index 2492f796a..000000000 --- a/vendor/golang.org/x/crypto/cryptobyte/asn1.go +++ /dev/null @@ -1,825 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package cryptobyte - -import ( - encoding_asn1 "encoding/asn1" - "fmt" - "math/big" - "reflect" - "time" - - "golang.org/x/crypto/cryptobyte/asn1" -) - -// This file contains ASN.1-related methods for String and Builder. - -// Builder - -// AddASN1Int64 appends a DER-encoded ASN.1 INTEGER. -func (b *Builder) AddASN1Int64(v int64) { - b.addASN1Signed(asn1.INTEGER, v) -} - -// AddASN1Int64WithTag appends a DER-encoded ASN.1 INTEGER with the -// given tag. -func (b *Builder) AddASN1Int64WithTag(v int64, tag asn1.Tag) { - b.addASN1Signed(tag, v) -} - -// AddASN1Enum appends a DER-encoded ASN.1 ENUMERATION. -func (b *Builder) AddASN1Enum(v int64) { - b.addASN1Signed(asn1.ENUM, v) -} - -func (b *Builder) addASN1Signed(tag asn1.Tag, v int64) { - b.AddASN1(tag, func(c *Builder) { - length := 1 - for i := v; i >= 0x80 || i < -0x80; i >>= 8 { - length++ - } - - for ; length > 0; length-- { - i := v >> uint((length-1)*8) & 0xff - c.AddUint8(uint8(i)) - } - }) -} - -// AddASN1Uint64 appends a DER-encoded ASN.1 INTEGER. -func (b *Builder) AddASN1Uint64(v uint64) { - b.AddASN1(asn1.INTEGER, func(c *Builder) { - length := 1 - for i := v; i >= 0x80; i >>= 8 { - length++ - } - - for ; length > 0; length-- { - i := v >> uint((length-1)*8) & 0xff - c.AddUint8(uint8(i)) - } - }) -} - -// AddASN1BigInt appends a DER-encoded ASN.1 INTEGER. -func (b *Builder) AddASN1BigInt(n *big.Int) { - if b.err != nil { - return - } - - b.AddASN1(asn1.INTEGER, func(c *Builder) { - if n.Sign() < 0 { - // A negative number has to be converted to two's-complement form. So we - // invert and subtract 1. If the most-significant-bit isn't set then - // we'll need to pad the beginning with 0xff in order to keep the number - // negative. - nMinus1 := new(big.Int).Neg(n) - nMinus1.Sub(nMinus1, bigOne) - bytes := nMinus1.Bytes() - for i := range bytes { - bytes[i] ^= 0xff - } - if len(bytes) == 0 || bytes[0]&0x80 == 0 { - c.add(0xff) - } - c.add(bytes...) - } else if n.Sign() == 0 { - c.add(0) - } else { - bytes := n.Bytes() - if bytes[0]&0x80 != 0 { - c.add(0) - } - c.add(bytes...) - } - }) -} - -// AddASN1OctetString appends a DER-encoded ASN.1 OCTET STRING. -func (b *Builder) AddASN1OctetString(bytes []byte) { - b.AddASN1(asn1.OCTET_STRING, func(c *Builder) { - c.AddBytes(bytes) - }) -} - -const generalizedTimeFormatStr = "20060102150405Z0700" - -// AddASN1GeneralizedTime appends a DER-encoded ASN.1 GENERALIZEDTIME. -func (b *Builder) AddASN1GeneralizedTime(t time.Time) { - if t.Year() < 0 || t.Year() > 9999 { - b.err = fmt.Errorf("cryptobyte: cannot represent %v as a GeneralizedTime", t) - return - } - b.AddASN1(asn1.GeneralizedTime, func(c *Builder) { - c.AddBytes([]byte(t.Format(generalizedTimeFormatStr))) - }) -} - -// AddASN1UTCTime appends a DER-encoded ASN.1 UTCTime. -func (b *Builder) AddASN1UTCTime(t time.Time) { - b.AddASN1(asn1.UTCTime, func(c *Builder) { - // As utilized by the X.509 profile, UTCTime can only - // represent the years 1950 through 2049. - if t.Year() < 1950 || t.Year() >= 2050 { - b.err = fmt.Errorf("cryptobyte: cannot represent %v as a UTCTime", t) - return - } - c.AddBytes([]byte(t.Format(defaultUTCTimeFormatStr))) - }) -} - -// AddASN1BitString appends a DER-encoded ASN.1 BIT STRING. This does not -// support BIT STRINGs that are not a whole number of bytes. -func (b *Builder) AddASN1BitString(data []byte) { - b.AddASN1(asn1.BIT_STRING, func(b *Builder) { - b.AddUint8(0) - b.AddBytes(data) - }) -} - -func (b *Builder) addBase128Int(n int64) { - var length int - if n == 0 { - length = 1 - } else { - for i := n; i > 0; i >>= 7 { - length++ - } - } - - for i := length - 1; i >= 0; i-- { - o := byte(n >> uint(i*7)) - o &= 0x7f - if i != 0 { - o |= 0x80 - } - - b.add(o) - } -} - -func isValidOID(oid encoding_asn1.ObjectIdentifier) bool { - if len(oid) < 2 { - return false - } - - if oid[0] > 2 || (oid[0] <= 1 && oid[1] >= 40) { - return false - } - - for _, v := range oid { - if v < 0 { - return false - } - } - - return true -} - -func (b *Builder) AddASN1ObjectIdentifier(oid encoding_asn1.ObjectIdentifier) { - b.AddASN1(asn1.OBJECT_IDENTIFIER, func(b *Builder) { - if !isValidOID(oid) { - b.err = fmt.Errorf("cryptobyte: invalid OID: %v", oid) - return - } - - b.addBase128Int(int64(oid[0])*40 + int64(oid[1])) - for _, v := range oid[2:] { - b.addBase128Int(int64(v)) - } - }) -} - -func (b *Builder) AddASN1Boolean(v bool) { - b.AddASN1(asn1.BOOLEAN, func(b *Builder) { - if v { - b.AddUint8(0xff) - } else { - b.AddUint8(0) - } - }) -} - -func (b *Builder) AddASN1NULL() { - b.add(uint8(asn1.NULL), 0) -} - -// MarshalASN1 calls encoding_asn1.Marshal on its input and appends the result if -// successful or records an error if one occurred. -func (b *Builder) MarshalASN1(v interface{}) { - // NOTE(martinkr): This is somewhat of a hack to allow propagation of - // encoding_asn1.Marshal errors into Builder.err. N.B. if you call MarshalASN1 with a - // value embedded into a struct, its tag information is lost. - if b.err != nil { - return - } - bytes, err := encoding_asn1.Marshal(v) - if err != nil { - b.err = err - return - } - b.AddBytes(bytes) -} - -// AddASN1 appends an ASN.1 object. The object is prefixed with the given tag. -// Tags greater than 30 are not supported and result in an error (i.e. -// low-tag-number form only). The child builder passed to the -// BuilderContinuation can be used to build the content of the ASN.1 object. -func (b *Builder) AddASN1(tag asn1.Tag, f BuilderContinuation) { - if b.err != nil { - return - } - // Identifiers with the low five bits set indicate high-tag-number format - // (two or more octets), which we don't support. - if tag&0x1f == 0x1f { - b.err = fmt.Errorf("cryptobyte: high-tag number identifier octects not supported: 0x%x", tag) - return - } - b.AddUint8(uint8(tag)) - b.addLengthPrefixed(1, true, f) -} - -// String - -// ReadASN1Boolean decodes an ASN.1 BOOLEAN and converts it to a boolean -// representation into out and advances. It reports whether the read -// was successful. -func (s *String) ReadASN1Boolean(out *bool) bool { - var bytes String - if !s.ReadASN1(&bytes, asn1.BOOLEAN) || len(bytes) != 1 { - return false - } - - switch bytes[0] { - case 0: - *out = false - case 0xff: - *out = true - default: - return false - } - - return true -} - -// ReadASN1Integer decodes an ASN.1 INTEGER into out and advances. If out does -// not point to an integer, to a big.Int, or to a []byte it panics. Only -// positive and zero values can be decoded into []byte, and they are returned as -// big-endian binary values that share memory with s. Positive values will have -// no leading zeroes, and zero will be returned as a single zero byte. -// ReadASN1Integer reports whether the read was successful. -func (s *String) ReadASN1Integer(out interface{}) bool { - switch out := out.(type) { - case *int, *int8, *int16, *int32, *int64: - var i int64 - if !s.readASN1Int64(&i) || reflect.ValueOf(out).Elem().OverflowInt(i) { - return false - } - reflect.ValueOf(out).Elem().SetInt(i) - return true - case *uint, *uint8, *uint16, *uint32, *uint64: - var u uint64 - if !s.readASN1Uint64(&u) || reflect.ValueOf(out).Elem().OverflowUint(u) { - return false - } - reflect.ValueOf(out).Elem().SetUint(u) - return true - case *big.Int: - return s.readASN1BigInt(out) - case *[]byte: - return s.readASN1Bytes(out) - default: - panic("out does not point to an integer type") - } -} - -func checkASN1Integer(bytes []byte) bool { - if len(bytes) == 0 { - // An INTEGER is encoded with at least one octet. - return false - } - if len(bytes) == 1 { - return true - } - if bytes[0] == 0 && bytes[1]&0x80 == 0 || bytes[0] == 0xff && bytes[1]&0x80 == 0x80 { - // Value is not minimally encoded. - return false - } - return true -} - -var bigOne = big.NewInt(1) - -func (s *String) readASN1BigInt(out *big.Int) bool { - var bytes String - if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) { - return false - } - if bytes[0]&0x80 == 0x80 { - // Negative number. - neg := make([]byte, len(bytes)) - for i, b := range bytes { - neg[i] = ^b - } - out.SetBytes(neg) - out.Add(out, bigOne) - out.Neg(out) - } else { - out.SetBytes(bytes) - } - return true -} - -func (s *String) readASN1Bytes(out *[]byte) bool { - var bytes String - if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) { - return false - } - if bytes[0]&0x80 == 0x80 { - return false - } - for len(bytes) > 1 && bytes[0] == 0 { - bytes = bytes[1:] - } - *out = bytes - return true -} - -func (s *String) readASN1Int64(out *int64) bool { - var bytes String - if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) || !asn1Signed(out, bytes) { - return false - } - return true -} - -func asn1Signed(out *int64, n []byte) bool { - length := len(n) - if length > 8 { - return false - } - for i := 0; i < length; i++ { - *out <<= 8 - *out |= int64(n[i]) - } - // Shift up and down in order to sign extend the result. - *out <<= 64 - uint8(length)*8 - *out >>= 64 - uint8(length)*8 - return true -} - -func (s *String) readASN1Uint64(out *uint64) bool { - var bytes String - if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) || !asn1Unsigned(out, bytes) { - return false - } - return true -} - -func asn1Unsigned(out *uint64, n []byte) bool { - length := len(n) - if length > 9 || length == 9 && n[0] != 0 { - // Too large for uint64. - return false - } - if n[0]&0x80 != 0 { - // Negative number. - return false - } - for i := 0; i < length; i++ { - *out <<= 8 - *out |= uint64(n[i]) - } - return true -} - -// ReadASN1Int64WithTag decodes an ASN.1 INTEGER with the given tag into out -// and advances. It reports whether the read was successful and resulted in a -// value that can be represented in an int64. -func (s *String) ReadASN1Int64WithTag(out *int64, tag asn1.Tag) bool { - var bytes String - return s.ReadASN1(&bytes, tag) && checkASN1Integer(bytes) && asn1Signed(out, bytes) -} - -// ReadASN1Enum decodes an ASN.1 ENUMERATION into out and advances. It reports -// whether the read was successful. -func (s *String) ReadASN1Enum(out *int) bool { - var bytes String - var i int64 - if !s.ReadASN1(&bytes, asn1.ENUM) || !checkASN1Integer(bytes) || !asn1Signed(&i, bytes) { - return false - } - if int64(int(i)) != i { - return false - } - *out = int(i) - return true -} - -func (s *String) readBase128Int(out *int) bool { - ret := 0 - for i := 0; len(*s) > 0; i++ { - if i == 5 { - return false - } - // Avoid overflowing int on a 32-bit platform. - // We don't want different behavior based on the architecture. - if ret >= 1<<(31-7) { - return false - } - ret <<= 7 - b := s.read(1)[0] - - // ITU-T X.690, section 8.19.2: - // The subidentifier shall be encoded in the fewest possible octets, - // that is, the leading octet of the subidentifier shall not have the value 0x80. - if i == 0 && b == 0x80 { - return false - } - - ret |= int(b & 0x7f) - if b&0x80 == 0 { - *out = ret - return true - } - } - return false // truncated -} - -// ReadASN1ObjectIdentifier decodes an ASN.1 OBJECT IDENTIFIER into out and -// advances. It reports whether the read was successful. -func (s *String) ReadASN1ObjectIdentifier(out *encoding_asn1.ObjectIdentifier) bool { - var bytes String - if !s.ReadASN1(&bytes, asn1.OBJECT_IDENTIFIER) || len(bytes) == 0 { - return false - } - - // In the worst case, we get two elements from the first byte (which is - // encoded differently) and then every varint is a single byte long. - components := make([]int, len(bytes)+1) - - // The first varint is 40*value1 + value2: - // According to this packing, value1 can take the values 0, 1 and 2 only. - // When value1 = 0 or value1 = 1, then value2 is <= 39. When value1 = 2, - // then there are no restrictions on value2. - var v int - if !bytes.readBase128Int(&v) { - return false - } - if v < 80 { - components[0] = v / 40 - components[1] = v % 40 - } else { - components[0] = 2 - components[1] = v - 80 - } - - i := 2 - for ; len(bytes) > 0; i++ { - if !bytes.readBase128Int(&v) { - return false - } - components[i] = v - } - *out = components[:i] - return true -} - -// ReadASN1GeneralizedTime decodes an ASN.1 GENERALIZEDTIME into out and -// advances. It reports whether the read was successful. -func (s *String) ReadASN1GeneralizedTime(out *time.Time) bool { - var bytes String - if !s.ReadASN1(&bytes, asn1.GeneralizedTime) { - return false - } - t := string(bytes) - res, err := time.Parse(generalizedTimeFormatStr, t) - if err != nil { - return false - } - if serialized := res.Format(generalizedTimeFormatStr); serialized != t { - return false - } - *out = res - return true -} - -const defaultUTCTimeFormatStr = "060102150405Z0700" - -// ReadASN1UTCTime decodes an ASN.1 UTCTime into out and advances. -// It reports whether the read was successful. -func (s *String) ReadASN1UTCTime(out *time.Time) bool { - var bytes String - if !s.ReadASN1(&bytes, asn1.UTCTime) { - return false - } - t := string(bytes) - - formatStr := defaultUTCTimeFormatStr - var err error - res, err := time.Parse(formatStr, t) - if err != nil { - // Fallback to minute precision if we can't parse second - // precision. If we are following X.509 or X.690 we shouldn't - // support this, but we do. - formatStr = "0601021504Z0700" - res, err = time.Parse(formatStr, t) - } - if err != nil { - return false - } - - if serialized := res.Format(formatStr); serialized != t { - return false - } - - if res.Year() >= 2050 { - // UTCTime interprets the low order digits 50-99 as 1950-99. - // This only applies to its use in the X.509 profile. - // See https://tools.ietf.org/html/rfc5280#section-4.1.2.5.1 - res = res.AddDate(-100, 0, 0) - } - *out = res - return true -} - -// ReadASN1BitString decodes an ASN.1 BIT STRING into out and advances. -// It reports whether the read was successful. -func (s *String) ReadASN1BitString(out *encoding_asn1.BitString) bool { - var bytes String - if !s.ReadASN1(&bytes, asn1.BIT_STRING) || len(bytes) == 0 || - len(bytes)*8/8 != len(bytes) { - return false - } - - paddingBits := bytes[0] - bytes = bytes[1:] - if paddingBits > 7 || - len(bytes) == 0 && paddingBits != 0 || - len(bytes) > 0 && bytes[len(bytes)-1]&(1< 4 || len(*s) < int(2+lenLen) { - return false - } - - lenBytes := String((*s)[2 : 2+lenLen]) - if !lenBytes.readUnsigned(&len32, int(lenLen)) { - return false - } - - // ITU-T X.690 section 10.1 (DER length forms) requires encoding the length - // with the minimum number of octets. - if len32 < 128 { - // Length should have used short-form encoding. - return false - } - if len32>>((lenLen-1)*8) == 0 { - // Leading octet is 0. Length should have been at least one byte shorter. - return false - } - - headerLen = 2 + uint32(lenLen) - if headerLen+len32 < len32 { - // Overflow. - return false - } - length = headerLen + len32 - } - - if int(length) < 0 || !s.ReadBytes((*[]byte)(out), int(length)) { - return false - } - if skipHeader && !out.Skip(int(headerLen)) { - panic("cryptobyte: internal error") - } - - return true -} diff --git a/vendor/golang.org/x/crypto/cryptobyte/asn1/asn1.go b/vendor/golang.org/x/crypto/cryptobyte/asn1/asn1.go deleted file mode 100644 index cda8e3edf..000000000 --- a/vendor/golang.org/x/crypto/cryptobyte/asn1/asn1.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package asn1 contains supporting types for parsing and building ASN.1 -// messages with the cryptobyte package. -package asn1 // import "golang.org/x/crypto/cryptobyte/asn1" - -// Tag represents an ASN.1 identifier octet, consisting of a tag number -// (indicating a type) and class (such as context-specific or constructed). -// -// Methods in the cryptobyte package only support the low-tag-number form, i.e. -// a single identifier octet with bits 7-8 encoding the class and bits 1-6 -// encoding the tag number. -type Tag uint8 - -const ( - classConstructed = 0x20 - classContextSpecific = 0x80 -) - -// Constructed returns t with the constructed class bit set. -func (t Tag) Constructed() Tag { return t | classConstructed } - -// ContextSpecific returns t with the context-specific class bit set. -func (t Tag) ContextSpecific() Tag { return t | classContextSpecific } - -// The following is a list of standard tag and class combinations. -const ( - BOOLEAN = Tag(1) - INTEGER = Tag(2) - BIT_STRING = Tag(3) - OCTET_STRING = Tag(4) - NULL = Tag(5) - OBJECT_IDENTIFIER = Tag(6) - ENUM = Tag(10) - UTF8String = Tag(12) - SEQUENCE = Tag(16 | classConstructed) - SET = Tag(17 | classConstructed) - PrintableString = Tag(19) - T61String = Tag(20) - IA5String = Tag(22) - UTCTime = Tag(23) - GeneralizedTime = Tag(24) - GeneralString = Tag(27) -) diff --git a/vendor/golang.org/x/crypto/cryptobyte/builder.go b/vendor/golang.org/x/crypto/cryptobyte/builder.go deleted file mode 100644 index cf254f5f1..000000000 --- a/vendor/golang.org/x/crypto/cryptobyte/builder.go +++ /dev/null @@ -1,350 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package cryptobyte - -import ( - "errors" - "fmt" -) - -// A Builder builds byte strings from fixed-length and length-prefixed values. -// Builders either allocate space as needed, or are ‘fixed’, which means that -// they write into a given buffer and produce an error if it's exhausted. -// -// The zero value is a usable Builder that allocates space as needed. -// -// Simple values are marshaled and appended to a Builder using methods on the -// Builder. Length-prefixed values are marshaled by providing a -// BuilderContinuation, which is a function that writes the inner contents of -// the value to a given Builder. See the documentation for BuilderContinuation -// for details. -type Builder struct { - err error - result []byte - fixedSize bool - child *Builder - offset int - pendingLenLen int - pendingIsASN1 bool - inContinuation *bool -} - -// NewBuilder creates a Builder that appends its output to the given buffer. -// Like append(), the slice will be reallocated if its capacity is exceeded. -// Use Bytes to get the final buffer. -func NewBuilder(buffer []byte) *Builder { - return &Builder{ - result: buffer, - } -} - -// NewFixedBuilder creates a Builder that appends its output into the given -// buffer. This builder does not reallocate the output buffer. Writes that -// would exceed the buffer's capacity are treated as an error. -func NewFixedBuilder(buffer []byte) *Builder { - return &Builder{ - result: buffer, - fixedSize: true, - } -} - -// SetError sets the value to be returned as the error from Bytes. Writes -// performed after calling SetError are ignored. -func (b *Builder) SetError(err error) { - b.err = err -} - -// Bytes returns the bytes written by the builder or an error if one has -// occurred during building. -func (b *Builder) Bytes() ([]byte, error) { - if b.err != nil { - return nil, b.err - } - return b.result[b.offset:], nil -} - -// BytesOrPanic returns the bytes written by the builder or panics if an error -// has occurred during building. -func (b *Builder) BytesOrPanic() []byte { - if b.err != nil { - panic(b.err) - } - return b.result[b.offset:] -} - -// AddUint8 appends an 8-bit value to the byte string. -func (b *Builder) AddUint8(v uint8) { - b.add(byte(v)) -} - -// AddUint16 appends a big-endian, 16-bit value to the byte string. -func (b *Builder) AddUint16(v uint16) { - b.add(byte(v>>8), byte(v)) -} - -// AddUint24 appends a big-endian, 24-bit value to the byte string. The highest -// byte of the 32-bit input value is silently truncated. -func (b *Builder) AddUint24(v uint32) { - b.add(byte(v>>16), byte(v>>8), byte(v)) -} - -// AddUint32 appends a big-endian, 32-bit value to the byte string. -func (b *Builder) AddUint32(v uint32) { - b.add(byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) -} - -// AddUint48 appends a big-endian, 48-bit value to the byte string. -func (b *Builder) AddUint48(v uint64) { - b.add(byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) -} - -// AddUint64 appends a big-endian, 64-bit value to the byte string. -func (b *Builder) AddUint64(v uint64) { - b.add(byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) -} - -// AddBytes appends a sequence of bytes to the byte string. -func (b *Builder) AddBytes(v []byte) { - b.add(v...) -} - -// BuilderContinuation is a continuation-passing interface for building -// length-prefixed byte sequences. Builder methods for length-prefixed -// sequences (AddUint8LengthPrefixed etc) will invoke the BuilderContinuation -// supplied to them. The child builder passed to the continuation can be used -// to build the content of the length-prefixed sequence. For example: -// -// parent := cryptobyte.NewBuilder() -// parent.AddUint8LengthPrefixed(func (child *Builder) { -// child.AddUint8(42) -// child.AddUint8LengthPrefixed(func (grandchild *Builder) { -// grandchild.AddUint8(5) -// }) -// }) -// -// It is an error to write more bytes to the child than allowed by the reserved -// length prefix. After the continuation returns, the child must be considered -// invalid, i.e. users must not store any copies or references of the child -// that outlive the continuation. -// -// If the continuation panics with a value of type BuildError then the inner -// error will be returned as the error from Bytes. If the child panics -// otherwise then Bytes will repanic with the same value. -type BuilderContinuation func(child *Builder) - -// BuildError wraps an error. If a BuilderContinuation panics with this value, -// the panic will be recovered and the inner error will be returned from -// Builder.Bytes. -type BuildError struct { - Err error -} - -// AddUint8LengthPrefixed adds a 8-bit length-prefixed byte sequence. -func (b *Builder) AddUint8LengthPrefixed(f BuilderContinuation) { - b.addLengthPrefixed(1, false, f) -} - -// AddUint16LengthPrefixed adds a big-endian, 16-bit length-prefixed byte sequence. -func (b *Builder) AddUint16LengthPrefixed(f BuilderContinuation) { - b.addLengthPrefixed(2, false, f) -} - -// AddUint24LengthPrefixed adds a big-endian, 24-bit length-prefixed byte sequence. -func (b *Builder) AddUint24LengthPrefixed(f BuilderContinuation) { - b.addLengthPrefixed(3, false, f) -} - -// AddUint32LengthPrefixed adds a big-endian, 32-bit length-prefixed byte sequence. -func (b *Builder) AddUint32LengthPrefixed(f BuilderContinuation) { - b.addLengthPrefixed(4, false, f) -} - -func (b *Builder) callContinuation(f BuilderContinuation, arg *Builder) { - if !*b.inContinuation { - *b.inContinuation = true - - defer func() { - *b.inContinuation = false - - r := recover() - if r == nil { - return - } - - if buildError, ok := r.(BuildError); ok { - b.err = buildError.Err - } else { - panic(r) - } - }() - } - - f(arg) -} - -func (b *Builder) addLengthPrefixed(lenLen int, isASN1 bool, f BuilderContinuation) { - // Subsequent writes can be ignored if the builder has encountered an error. - if b.err != nil { - return - } - - offset := len(b.result) - b.add(make([]byte, lenLen)...) - - if b.inContinuation == nil { - b.inContinuation = new(bool) - } - - b.child = &Builder{ - result: b.result, - fixedSize: b.fixedSize, - offset: offset, - pendingLenLen: lenLen, - pendingIsASN1: isASN1, - inContinuation: b.inContinuation, - } - - b.callContinuation(f, b.child) - b.flushChild() - if b.child != nil { - panic("cryptobyte: internal error") - } -} - -func (b *Builder) flushChild() { - if b.child == nil { - return - } - b.child.flushChild() - child := b.child - b.child = nil - - if child.err != nil { - b.err = child.err - return - } - - length := len(child.result) - child.pendingLenLen - child.offset - - if length < 0 { - panic("cryptobyte: internal error") // result unexpectedly shrunk - } - - if child.pendingIsASN1 { - // For ASN.1, we reserved a single byte for the length. If that turned out - // to be incorrect, we have to move the contents along in order to make - // space. - if child.pendingLenLen != 1 { - panic("cryptobyte: internal error") - } - var lenLen, lenByte uint8 - if int64(length) > 0xfffffffe { - b.err = errors.New("pending ASN.1 child too long") - return - } else if length > 0xffffff { - lenLen = 5 - lenByte = 0x80 | 4 - } else if length > 0xffff { - lenLen = 4 - lenByte = 0x80 | 3 - } else if length > 0xff { - lenLen = 3 - lenByte = 0x80 | 2 - } else if length > 0x7f { - lenLen = 2 - lenByte = 0x80 | 1 - } else { - lenLen = 1 - lenByte = uint8(length) - length = 0 - } - - // Insert the initial length byte, make space for successive length bytes, - // and adjust the offset. - child.result[child.offset] = lenByte - extraBytes := int(lenLen - 1) - if extraBytes != 0 { - child.add(make([]byte, extraBytes)...) - childStart := child.offset + child.pendingLenLen - copy(child.result[childStart+extraBytes:], child.result[childStart:]) - } - child.offset++ - child.pendingLenLen = extraBytes - } - - l := length - for i := child.pendingLenLen - 1; i >= 0; i-- { - child.result[child.offset+i] = uint8(l) - l >>= 8 - } - if l != 0 { - b.err = fmt.Errorf("cryptobyte: pending child length %d exceeds %d-byte length prefix", length, child.pendingLenLen) - return - } - - if b.fixedSize && &b.result[0] != &child.result[0] { - panic("cryptobyte: BuilderContinuation reallocated a fixed-size buffer") - } - - b.result = child.result -} - -func (b *Builder) add(bytes ...byte) { - if b.err != nil { - return - } - if b.child != nil { - panic("cryptobyte: attempted write while child is pending") - } - if len(b.result)+len(bytes) < len(bytes) { - b.err = errors.New("cryptobyte: length overflow") - } - if b.fixedSize && len(b.result)+len(bytes) > cap(b.result) { - b.err = errors.New("cryptobyte: Builder is exceeding its fixed-size buffer") - return - } - b.result = append(b.result, bytes...) -} - -// Unwrite rolls back non-negative n bytes written directly to the Builder. -// An attempt by a child builder passed to a continuation to unwrite bytes -// from its parent will panic. -func (b *Builder) Unwrite(n int) { - if b.err != nil { - return - } - if b.child != nil { - panic("cryptobyte: attempted unwrite while child is pending") - } - length := len(b.result) - b.pendingLenLen - b.offset - if length < 0 { - panic("cryptobyte: internal error") - } - if n < 0 { - panic("cryptobyte: attempted to unwrite negative number of bytes") - } - if n > length { - panic("cryptobyte: attempted to unwrite more than was written") - } - b.result = b.result[:len(b.result)-n] -} - -// A MarshalingValue marshals itself into a Builder. -type MarshalingValue interface { - // Marshal is called by Builder.AddValue. It receives a pointer to a builder - // to marshal itself into. It may return an error that occurred during - // marshaling, such as unset or invalid values. - Marshal(b *Builder) error -} - -// AddValue calls Marshal on v, passing a pointer to the builder to append to. -// If Marshal returns an error, it is set on the Builder so that subsequent -// appends don't have an effect. -func (b *Builder) AddValue(v MarshalingValue) { - err := v.Marshal(b) - if err != nil { - b.err = err - } -} diff --git a/vendor/golang.org/x/crypto/cryptobyte/string.go b/vendor/golang.org/x/crypto/cryptobyte/string.go deleted file mode 100644 index 10692a8a3..000000000 --- a/vendor/golang.org/x/crypto/cryptobyte/string.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package cryptobyte contains types that help with parsing and constructing -// length-prefixed, binary messages, including ASN.1 DER. (The asn1 subpackage -// contains useful ASN.1 constants.) -// -// The String type is for parsing. It wraps a []byte slice and provides helper -// functions for consuming structures, value by value. -// -// The Builder type is for constructing messages. It providers helper functions -// for appending values and also for appending length-prefixed submessages – -// without having to worry about calculating the length prefix ahead of time. -// -// See the documentation and examples for the Builder and String types to get -// started. -package cryptobyte // import "golang.org/x/crypto/cryptobyte" - -// String represents a string of bytes. It provides methods for parsing -// fixed-length and length-prefixed values from it. -type String []byte - -// read advances a String by n bytes and returns them. If less than n bytes -// remain, it returns nil. -func (s *String) read(n int) []byte { - if len(*s) < n || n < 0 { - return nil - } - v := (*s)[:n] - *s = (*s)[n:] - return v -} - -// Skip advances the String by n byte and reports whether it was successful. -func (s *String) Skip(n int) bool { - return s.read(n) != nil -} - -// ReadUint8 decodes an 8-bit value into out and advances over it. -// It reports whether the read was successful. -func (s *String) ReadUint8(out *uint8) bool { - v := s.read(1) - if v == nil { - return false - } - *out = uint8(v[0]) - return true -} - -// ReadUint16 decodes a big-endian, 16-bit value into out and advances over it. -// It reports whether the read was successful. -func (s *String) ReadUint16(out *uint16) bool { - v := s.read(2) - if v == nil { - return false - } - *out = uint16(v[0])<<8 | uint16(v[1]) - return true -} - -// ReadUint24 decodes a big-endian, 24-bit value into out and advances over it. -// It reports whether the read was successful. -func (s *String) ReadUint24(out *uint32) bool { - v := s.read(3) - if v == nil { - return false - } - *out = uint32(v[0])<<16 | uint32(v[1])<<8 | uint32(v[2]) - return true -} - -// ReadUint32 decodes a big-endian, 32-bit value into out and advances over it. -// It reports whether the read was successful. -func (s *String) ReadUint32(out *uint32) bool { - v := s.read(4) - if v == nil { - return false - } - *out = uint32(v[0])<<24 | uint32(v[1])<<16 | uint32(v[2])<<8 | uint32(v[3]) - return true -} - -// ReadUint48 decodes a big-endian, 48-bit value into out and advances over it. -// It reports whether the read was successful. -func (s *String) ReadUint48(out *uint64) bool { - v := s.read(6) - if v == nil { - return false - } - *out = uint64(v[0])<<40 | uint64(v[1])<<32 | uint64(v[2])<<24 | uint64(v[3])<<16 | uint64(v[4])<<8 | uint64(v[5]) - return true -} - -// ReadUint64 decodes a big-endian, 64-bit value into out and advances over it. -// It reports whether the read was successful. -func (s *String) ReadUint64(out *uint64) bool { - v := s.read(8) - if v == nil { - return false - } - *out = uint64(v[0])<<56 | uint64(v[1])<<48 | uint64(v[2])<<40 | uint64(v[3])<<32 | uint64(v[4])<<24 | uint64(v[5])<<16 | uint64(v[6])<<8 | uint64(v[7]) - return true -} - -func (s *String) readUnsigned(out *uint32, length int) bool { - v := s.read(length) - if v == nil { - return false - } - var result uint32 - for i := 0; i < length; i++ { - result <<= 8 - result |= uint32(v[i]) - } - *out = result - return true -} - -func (s *String) readLengthPrefixed(lenLen int, outChild *String) bool { - lenBytes := s.read(lenLen) - if lenBytes == nil { - return false - } - var length uint32 - for _, b := range lenBytes { - length = length << 8 - length = length | uint32(b) - } - v := s.read(int(length)) - if v == nil { - return false - } - *outChild = v - return true -} - -// ReadUint8LengthPrefixed reads the content of an 8-bit length-prefixed value -// into out and advances over it. It reports whether the read was successful. -func (s *String) ReadUint8LengthPrefixed(out *String) bool { - return s.readLengthPrefixed(1, out) -} - -// ReadUint16LengthPrefixed reads the content of a big-endian, 16-bit -// length-prefixed value into out and advances over it. It reports whether the -// read was successful. -func (s *String) ReadUint16LengthPrefixed(out *String) bool { - return s.readLengthPrefixed(2, out) -} - -// ReadUint24LengthPrefixed reads the content of a big-endian, 24-bit -// length-prefixed value into out and advances over it. It reports whether -// the read was successful. -func (s *String) ReadUint24LengthPrefixed(out *String) bool { - return s.readLengthPrefixed(3, out) -} - -// ReadBytes reads n bytes into out and advances over them. It reports -// whether the read was successful. -func (s *String) ReadBytes(out *[]byte, n int) bool { - v := s.read(n) - if v == nil { - return false - } - *out = v - return true -} - -// CopyBytes copies len(out) bytes into out and advances over them. It reports -// whether the copy operation was successful -func (s *String) CopyBytes(out []byte) bool { - n := len(out) - v := s.read(n) - if v == nil { - return false - } - return copy(out, v) == n -} - -// Empty reports whether the string does not contain any bytes. -func (s String) Empty() bool { - return len(s) == 0 -} diff --git a/vendor/golang.org/x/exp/constraints/constraints.go b/vendor/golang.org/x/exp/constraints/constraints.go deleted file mode 100644 index 2c033dff4..000000000 --- a/vendor/golang.org/x/exp/constraints/constraints.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package constraints defines a set of useful constraints to be used -// with type parameters. -package constraints - -// Signed is a constraint that permits any signed integer type. -// If future releases of Go add new predeclared signed integer types, -// this constraint will be modified to include them. -type Signed interface { - ~int | ~int8 | ~int16 | ~int32 | ~int64 -} - -// Unsigned is a constraint that permits any unsigned integer type. -// If future releases of Go add new predeclared unsigned integer types, -// this constraint will be modified to include them. -type Unsigned interface { - ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr -} - -// Integer is a constraint that permits any integer type. -// If future releases of Go add new predeclared integer types, -// this constraint will be modified to include them. -type Integer interface { - Signed | Unsigned -} - -// Float is a constraint that permits any floating-point type. -// If future releases of Go add new predeclared floating-point types, -// this constraint will be modified to include them. -type Float interface { - ~float32 | ~float64 -} - -// Complex is a constraint that permits any complex numeric type. -// If future releases of Go add new predeclared complex numeric types, -// this constraint will be modified to include them. -type Complex interface { - ~complex64 | ~complex128 -} - -// Ordered is a constraint that permits any ordered type: any type -// that supports the operators < <= >= >. -// If future releases of Go add new ordered types, -// this constraint will be modified to include them. -type Ordered interface { - Integer | Float | ~string -} diff --git a/vendor/golang.org/x/exp/rand/exp.go b/vendor/golang.org/x/exp/rand/exp.go new file mode 100644 index 000000000..083867276 --- /dev/null +++ b/vendor/golang.org/x/exp/rand/exp.go @@ -0,0 +1,221 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rand + +import ( + "math" +) + +/* + * Exponential distribution + * + * See "The Ziggurat Method for Generating Random Variables" + * (Marsaglia & Tsang, 2000) + * http://www.jstatsoft.org/v05/i08/paper [pdf] + */ + +const ( + re = 7.69711747013104972 +) + +// ExpFloat64 returns an exponentially distributed float64 in the range +// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter +// (lambda) is 1 and whose mean is 1/lambda (1). +// To produce a distribution with a different rate parameter, +// callers can adjust the output using: +// +// sample = ExpFloat64() / desiredRateParameter +func (r *Rand) ExpFloat64() float64 { + for { + j := r.Uint32() + i := j & 0xFF + x := float64(j) * float64(we[i]) + if j < ke[i] { + return x + } + if i == 0 { + return re - math.Log(r.Float64()) + } + if fe[i]+float32(r.Float64())*(fe[i-1]-fe[i]) < float32(math.Exp(-x)) { + return x + } + } +} + +var ke = [256]uint32{ + 0xe290a139, 0x0, 0x9beadebc, 0xc377ac71, 0xd4ddb990, + 0xde893fb8, 0xe4a8e87c, 0xe8dff16a, 0xebf2deab, 0xee49a6e8, + 0xf0204efd, 0xf19bdb8e, 0xf2d458bb, 0xf3da104b, 0xf4b86d78, + 0xf577ad8a, 0xf61de83d, 0xf6afb784, 0xf730a573, 0xf7a37651, + 0xf80a5bb6, 0xf867189d, 0xf8bb1b4f, 0xf9079062, 0xf94d70ca, + 0xf98d8c7d, 0xf9c8928a, 0xf9ff175b, 0xfa319996, 0xfa6085f8, + 0xfa8c3a62, 0xfab5084e, 0xfadb36c8, 0xfaff0410, 0xfb20a6ea, + 0xfb404fb4, 0xfb5e2951, 0xfb7a59e9, 0xfb95038c, 0xfbae44ba, + 0xfbc638d8, 0xfbdcf892, 0xfbf29a30, 0xfc0731df, 0xfc1ad1ed, + 0xfc2d8b02, 0xfc3f6c4d, 0xfc5083ac, 0xfc60ddd1, 0xfc708662, + 0xfc7f8810, 0xfc8decb4, 0xfc9bbd62, 0xfca9027c, 0xfcb5c3c3, + 0xfcc20864, 0xfccdd70a, 0xfcd935e3, 0xfce42ab0, 0xfceebace, + 0xfcf8eb3b, 0xfd02c0a0, 0xfd0c3f59, 0xfd156b7b, 0xfd1e48d6, + 0xfd26daff, 0xfd2f2552, 0xfd372af7, 0xfd3eeee5, 0xfd4673e7, + 0xfd4dbc9e, 0xfd54cb85, 0xfd5ba2f2, 0xfd62451b, 0xfd68b415, + 0xfd6ef1da, 0xfd750047, 0xfd7ae120, 0xfd809612, 0xfd8620b4, + 0xfd8b8285, 0xfd90bcf5, 0xfd95d15e, 0xfd9ac10b, 0xfd9f8d36, + 0xfda43708, 0xfda8bf9e, 0xfdad2806, 0xfdb17141, 0xfdb59c46, + 0xfdb9a9fd, 0xfdbd9b46, 0xfdc170f6, 0xfdc52bd8, 0xfdc8ccac, + 0xfdcc542d, 0xfdcfc30b, 0xfdd319ef, 0xfdd6597a, 0xfdd98245, + 0xfddc94e5, 0xfddf91e6, 0xfde279ce, 0xfde54d1f, 0xfde80c52, + 0xfdeab7de, 0xfded5034, 0xfdefd5be, 0xfdf248e3, 0xfdf4aa06, + 0xfdf6f984, 0xfdf937b6, 0xfdfb64f4, 0xfdfd818d, 0xfdff8dd0, + 0xfe018a08, 0xfe03767a, 0xfe05536c, 0xfe07211c, 0xfe08dfc9, + 0xfe0a8fab, 0xfe0c30fb, 0xfe0dc3ec, 0xfe0f48b1, 0xfe10bf76, + 0xfe122869, 0xfe1383b4, 0xfe14d17c, 0xfe1611e7, 0xfe174516, + 0xfe186b2a, 0xfe19843e, 0xfe1a9070, 0xfe1b8fd6, 0xfe1c8289, + 0xfe1d689b, 0xfe1e4220, 0xfe1f0f26, 0xfe1fcfbc, 0xfe2083ed, + 0xfe212bc3, 0xfe21c745, 0xfe225678, 0xfe22d95f, 0xfe234ffb, + 0xfe23ba4a, 0xfe241849, 0xfe2469f2, 0xfe24af3c, 0xfe24e81e, + 0xfe25148b, 0xfe253474, 0xfe2547c7, 0xfe254e70, 0xfe25485a, + 0xfe25356a, 0xfe251586, 0xfe24e88f, 0xfe24ae64, 0xfe2466e1, + 0xfe2411df, 0xfe23af34, 0xfe233eb4, 0xfe22c02c, 0xfe22336b, + 0xfe219838, 0xfe20ee58, 0xfe20358c, 0xfe1f6d92, 0xfe1e9621, + 0xfe1daef0, 0xfe1cb7ac, 0xfe1bb002, 0xfe1a9798, 0xfe196e0d, + 0xfe1832fd, 0xfe16e5fe, 0xfe15869d, 0xfe141464, 0xfe128ed3, + 0xfe10f565, 0xfe0f478c, 0xfe0d84b1, 0xfe0bac36, 0xfe09bd73, + 0xfe07b7b5, 0xfe059a40, 0xfe03644c, 0xfe011504, 0xfdfeab88, + 0xfdfc26e9, 0xfdf98629, 0xfdf6c83b, 0xfdf3ec01, 0xfdf0f04a, + 0xfdedd3d1, 0xfdea953d, 0xfde7331e, 0xfde3abe9, 0xfddffdfb, + 0xfddc2791, 0xfdd826cd, 0xfdd3f9a8, 0xfdcf9dfc, 0xfdcb1176, + 0xfdc65198, 0xfdc15bb3, 0xfdbc2ce2, 0xfdb6c206, 0xfdb117be, + 0xfdab2a63, 0xfda4f5fd, 0xfd9e7640, 0xfd97a67a, 0xfd908192, + 0xfd8901f2, 0xfd812182, 0xfd78d98e, 0xfd7022bb, 0xfd66f4ed, + 0xfd5d4732, 0xfd530f9c, 0xfd48432b, 0xfd3cd59a, 0xfd30b936, + 0xfd23dea4, 0xfd16349e, 0xfd07a7a3, 0xfcf8219b, 0xfce7895b, + 0xfcd5c220, 0xfcc2aadb, 0xfcae1d5e, 0xfc97ed4e, 0xfc7fe6d4, + 0xfc65ccf3, 0xfc495762, 0xfc2a2fc8, 0xfc07ee19, 0xfbe213c1, + 0xfbb8051a, 0xfb890078, 0xfb5411a5, 0xfb180005, 0xfad33482, + 0xfa839276, 0xfa263b32, 0xf9b72d1c, 0xf930a1a2, 0xf889f023, + 0xf7b577d2, 0xf69c650c, 0xf51530f0, 0xf2cb0e3c, 0xeeefb15d, + 0xe6da6ecf, +} +var we = [256]float32{ + 2.0249555e-09, 1.486674e-11, 2.4409617e-11, 3.1968806e-11, + 3.844677e-11, 4.4228204e-11, 4.9516443e-11, 5.443359e-11, + 5.905944e-11, 6.344942e-11, 6.7643814e-11, 7.1672945e-11, + 7.556032e-11, 7.932458e-11, 8.298079e-11, 8.654132e-11, + 9.0016515e-11, 9.3415074e-11, 9.674443e-11, 1.0001099e-10, + 1.03220314e-10, 1.06377254e-10, 1.09486115e-10, 1.1255068e-10, + 1.1557435e-10, 1.1856015e-10, 1.2151083e-10, 1.2442886e-10, + 1.2731648e-10, 1.3017575e-10, 1.3300853e-10, 1.3581657e-10, + 1.3860142e-10, 1.4136457e-10, 1.4410738e-10, 1.4683108e-10, + 1.4953687e-10, 1.5222583e-10, 1.54899e-10, 1.5755733e-10, + 1.6020171e-10, 1.6283301e-10, 1.6545203e-10, 1.6805951e-10, + 1.7065617e-10, 1.732427e-10, 1.7581973e-10, 1.7838787e-10, + 1.8094774e-10, 1.8349985e-10, 1.8604476e-10, 1.8858298e-10, + 1.9111498e-10, 1.9364126e-10, 1.9616223e-10, 1.9867835e-10, + 2.0119004e-10, 2.0369768e-10, 2.0620168e-10, 2.087024e-10, + 2.1120022e-10, 2.136955e-10, 2.1618855e-10, 2.1867974e-10, + 2.2116936e-10, 2.2365775e-10, 2.261452e-10, 2.2863202e-10, + 2.311185e-10, 2.3360494e-10, 2.360916e-10, 2.3857874e-10, + 2.4106667e-10, 2.4355562e-10, 2.4604588e-10, 2.485377e-10, + 2.5103128e-10, 2.5352695e-10, 2.560249e-10, 2.585254e-10, + 2.6102867e-10, 2.6353494e-10, 2.6604446e-10, 2.6855745e-10, + 2.7107416e-10, 2.7359479e-10, 2.761196e-10, 2.7864877e-10, + 2.8118255e-10, 2.8372119e-10, 2.8626485e-10, 2.888138e-10, + 2.9136826e-10, 2.939284e-10, 2.9649452e-10, 2.9906677e-10, + 3.016454e-10, 3.0423064e-10, 3.0682268e-10, 3.0942177e-10, + 3.1202813e-10, 3.1464195e-10, 3.1726352e-10, 3.19893e-10, + 3.2253064e-10, 3.251767e-10, 3.2783135e-10, 3.3049485e-10, + 3.3316744e-10, 3.3584938e-10, 3.3854083e-10, 3.4124212e-10, + 3.4395342e-10, 3.46675e-10, 3.4940711e-10, 3.5215003e-10, + 3.5490397e-10, 3.5766917e-10, 3.6044595e-10, 3.6323455e-10, + 3.660352e-10, 3.6884823e-10, 3.7167386e-10, 3.745124e-10, + 3.773641e-10, 3.802293e-10, 3.8310827e-10, 3.860013e-10, + 3.8890866e-10, 3.918307e-10, 3.9476775e-10, 3.9772008e-10, + 4.0068804e-10, 4.0367196e-10, 4.0667217e-10, 4.09689e-10, + 4.1272286e-10, 4.1577405e-10, 4.1884296e-10, 4.2192994e-10, + 4.250354e-10, 4.281597e-10, 4.313033e-10, 4.3446652e-10, + 4.3764986e-10, 4.408537e-10, 4.4407847e-10, 4.4732465e-10, + 4.5059267e-10, 4.5388301e-10, 4.571962e-10, 4.6053267e-10, + 4.6389292e-10, 4.6727755e-10, 4.70687e-10, 4.741219e-10, + 4.7758275e-10, 4.810702e-10, 4.845848e-10, 4.8812715e-10, + 4.9169796e-10, 4.9529775e-10, 4.989273e-10, 5.0258725e-10, + 5.0627835e-10, 5.100013e-10, 5.1375687e-10, 5.1754584e-10, + 5.21369e-10, 5.2522725e-10, 5.2912136e-10, 5.330522e-10, + 5.370208e-10, 5.4102806e-10, 5.45075e-10, 5.491625e-10, + 5.532918e-10, 5.5746385e-10, 5.616799e-10, 5.6594107e-10, + 5.7024857e-10, 5.746037e-10, 5.7900773e-10, 5.834621e-10, + 5.8796823e-10, 5.925276e-10, 5.971417e-10, 6.018122e-10, + 6.065408e-10, 6.113292e-10, 6.1617933e-10, 6.2109295e-10, + 6.260722e-10, 6.3111916e-10, 6.3623595e-10, 6.4142497e-10, + 6.4668854e-10, 6.5202926e-10, 6.5744976e-10, 6.6295286e-10, + 6.6854156e-10, 6.742188e-10, 6.79988e-10, 6.858526e-10, + 6.9181616e-10, 6.978826e-10, 7.04056e-10, 7.103407e-10, + 7.167412e-10, 7.2326256e-10, 7.2990985e-10, 7.366886e-10, + 7.4360473e-10, 7.5066453e-10, 7.5787476e-10, 7.6524265e-10, + 7.7277595e-10, 7.80483e-10, 7.883728e-10, 7.9645507e-10, + 8.047402e-10, 8.1323964e-10, 8.219657e-10, 8.309319e-10, + 8.401528e-10, 8.496445e-10, 8.594247e-10, 8.6951274e-10, + 8.799301e-10, 8.9070046e-10, 9.018503e-10, 9.134092e-10, + 9.254101e-10, 9.378904e-10, 9.508923e-10, 9.644638e-10, + 9.786603e-10, 9.935448e-10, 1.0091913e-09, 1.025686e-09, + 1.0431306e-09, 1.0616465e-09, 1.08138e-09, 1.1025096e-09, + 1.1252564e-09, 1.1498986e-09, 1.1767932e-09, 1.206409e-09, + 1.2393786e-09, 1.276585e-09, 1.3193139e-09, 1.3695435e-09, + 1.4305498e-09, 1.508365e-09, 1.6160854e-09, 1.7921248e-09, +} +var fe = [256]float32{ + 1, 0.9381437, 0.90046996, 0.87170434, 0.8477855, 0.8269933, + 0.8084217, 0.7915276, 0.77595687, 0.7614634, 0.7478686, + 0.7350381, 0.72286767, 0.71127474, 0.70019263, 0.6895665, + 0.67935055, 0.6695063, 0.66000086, 0.65080583, 0.6418967, + 0.63325197, 0.6248527, 0.6166822, 0.60872537, 0.60096896, + 0.5934009, 0.58601034, 0.5787874, 0.57172304, 0.5648092, + 0.5580383, 0.5514034, 0.5448982, 0.5385169, 0.53225386, + 0.5261042, 0.52006316, 0.5141264, 0.50828975, 0.5025495, + 0.496902, 0.49134386, 0.485872, 0.48048335, 0.4751752, + 0.46994483, 0.46478975, 0.45970762, 0.45469615, 0.44975325, + 0.44487688, 0.44006512, 0.43531612, 0.43062815, 0.42599955, + 0.42142874, 0.4169142, 0.41245446, 0.40804818, 0.403694, + 0.3993907, 0.39513698, 0.39093173, 0.38677382, 0.38266218, + 0.37859577, 0.37457356, 0.37059465, 0.3666581, 0.362763, + 0.35890847, 0.35509375, 0.351318, 0.3475805, 0.34388044, + 0.34021714, 0.3365899, 0.33299807, 0.32944095, 0.32591796, + 0.3224285, 0.3189719, 0.31554767, 0.31215525, 0.30879408, + 0.3054636, 0.3021634, 0.29889292, 0.2956517, 0.29243928, + 0.28925523, 0.28609908, 0.28297043, 0.27986884, 0.27679393, + 0.2737453, 0.2707226, 0.2677254, 0.26475343, 0.26180625, + 0.25888354, 0.25598502, 0.2531103, 0.25025907, 0.24743107, + 0.24462597, 0.24184346, 0.23908329, 0.23634516, 0.23362878, + 0.23093392, 0.2282603, 0.22560766, 0.22297576, 0.22036438, + 0.21777324, 0.21520215, 0.21265087, 0.21011916, 0.20760682, + 0.20511365, 0.20263945, 0.20018397, 0.19774707, 0.19532852, + 0.19292815, 0.19054577, 0.1881812, 0.18583426, 0.18350479, + 0.1811926, 0.17889754, 0.17661946, 0.17435817, 0.17211354, + 0.1698854, 0.16767362, 0.16547804, 0.16329853, 0.16113494, + 0.15898713, 0.15685499, 0.15473837, 0.15263714, 0.15055119, + 0.14848037, 0.14642459, 0.14438373, 0.14235765, 0.14034624, + 0.13834943, 0.13636707, 0.13439907, 0.13244532, 0.13050574, + 0.1285802, 0.12666863, 0.12477092, 0.12288698, 0.12101672, + 0.119160056, 0.1173169, 0.115487166, 0.11367077, 0.11186763, + 0.11007768, 0.10830083, 0.10653701, 0.10478614, 0.10304816, + 0.101323, 0.09961058, 0.09791085, 0.09622374, 0.09454919, + 0.09288713, 0.091237515, 0.08960028, 0.087975375, 0.08636274, + 0.08476233, 0.083174095, 0.081597984, 0.08003395, 0.07848195, + 0.076941945, 0.07541389, 0.07389775, 0.072393484, 0.07090106, + 0.069420435, 0.06795159, 0.066494495, 0.06504912, 0.063615434, + 0.062193416, 0.060783047, 0.059384305, 0.057997175, + 0.05662164, 0.05525769, 0.053905312, 0.052564494, 0.051235236, + 0.049917534, 0.048611384, 0.047316793, 0.046033762, 0.0447623, + 0.043502413, 0.042254124, 0.041017443, 0.039792392, + 0.038578995, 0.037377283, 0.036187284, 0.035009038, + 0.033842582, 0.032687962, 0.031545233, 0.030414443, 0.02929566, + 0.02818895, 0.027094385, 0.026012046, 0.024942026, 0.023884421, + 0.022839336, 0.021806888, 0.020787204, 0.019780423, 0.0187867, + 0.0178062, 0.016839107, 0.015885621, 0.014945968, 0.014020392, + 0.013109165, 0.012212592, 0.011331013, 0.01046481, 0.009614414, + 0.008780315, 0.007963077, 0.0071633533, 0.006381906, + 0.0056196423, 0.0048776558, 0.004157295, 0.0034602648, + 0.0027887989, 0.0021459677, 0.0015362998, 0.0009672693, + 0.00045413437, +} diff --git a/vendor/golang.org/x/exp/rand/normal.go b/vendor/golang.org/x/exp/rand/normal.go new file mode 100644 index 000000000..b66da3a81 --- /dev/null +++ b/vendor/golang.org/x/exp/rand/normal.go @@ -0,0 +1,156 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rand + +import ( + "math" +) + +/* + * Normal distribution + * + * See "The Ziggurat Method for Generating Random Variables" + * (Marsaglia & Tsang, 2000) + * http://www.jstatsoft.org/v05/i08/paper [pdf] + */ + +const ( + rn = 3.442619855899 +) + +func absInt32(i int32) uint32 { + if i < 0 { + return uint32(-i) + } + return uint32(i) +} + +// NormFloat64 returns a normally distributed float64 in the range +// [-math.MaxFloat64, +math.MaxFloat64] with +// standard normal distribution (mean = 0, stddev = 1). +// To produce a different normal distribution, callers can +// adjust the output using: +// +// sample = NormFloat64() * desiredStdDev + desiredMean +func (r *Rand) NormFloat64() float64 { + for { + j := int32(r.Uint32()) // Possibly negative + i := j & 0x7F + x := float64(j) * float64(wn[i]) + if absInt32(j) < kn[i] { + // This case should be hit better than 99% of the time. + return x + } + + if i == 0 { + // This extra work is only required for the base strip. + for { + x = -math.Log(r.Float64()) * (1.0 / rn) + y := -math.Log(r.Float64()) + if y+y >= x*x { + break + } + } + if j > 0 { + return rn + x + } + return -rn - x + } + if fn[i]+float32(r.Float64())*(fn[i-1]-fn[i]) < float32(math.Exp(-.5*x*x)) { + return x + } + } +} + +var kn = [128]uint32{ + 0x76ad2212, 0x0, 0x600f1b53, 0x6ce447a6, 0x725b46a2, + 0x7560051d, 0x774921eb, 0x789a25bd, 0x799045c3, 0x7a4bce5d, + 0x7adf629f, 0x7b5682a6, 0x7bb8a8c6, 0x7c0ae722, 0x7c50cce7, + 0x7c8cec5b, 0x7cc12cd6, 0x7ceefed2, 0x7d177e0b, 0x7d3b8883, + 0x7d5bce6c, 0x7d78dd64, 0x7d932886, 0x7dab0e57, 0x7dc0dd30, + 0x7dd4d688, 0x7de73185, 0x7df81cea, 0x7e07c0a3, 0x7e163efa, + 0x7e23b587, 0x7e303dfd, 0x7e3beec2, 0x7e46db77, 0x7e51155d, + 0x7e5aabb3, 0x7e63abf7, 0x7e6c222c, 0x7e741906, 0x7e7b9a18, + 0x7e82adfa, 0x7e895c63, 0x7e8fac4b, 0x7e95a3fb, 0x7e9b4924, + 0x7ea0a0ef, 0x7ea5b00d, 0x7eaa7ac3, 0x7eaf04f3, 0x7eb3522a, + 0x7eb765a5, 0x7ebb4259, 0x7ebeeafd, 0x7ec2620a, 0x7ec5a9c4, + 0x7ec8c441, 0x7ecbb365, 0x7ece78ed, 0x7ed11671, 0x7ed38d62, + 0x7ed5df12, 0x7ed80cb4, 0x7eda175c, 0x7edc0005, 0x7eddc78e, + 0x7edf6ebf, 0x7ee0f647, 0x7ee25ebe, 0x7ee3a8a9, 0x7ee4d473, + 0x7ee5e276, 0x7ee6d2f5, 0x7ee7a620, 0x7ee85c10, 0x7ee8f4cd, + 0x7ee97047, 0x7ee9ce59, 0x7eea0eca, 0x7eea3147, 0x7eea3568, + 0x7eea1aab, 0x7ee9e071, 0x7ee98602, 0x7ee90a88, 0x7ee86d08, + 0x7ee7ac6a, 0x7ee6c769, 0x7ee5bc9c, 0x7ee48a67, 0x7ee32efc, + 0x7ee1a857, 0x7edff42f, 0x7ede0ffa, 0x7edbf8d9, 0x7ed9ab94, + 0x7ed7248d, 0x7ed45fae, 0x7ed1585c, 0x7ece095f, 0x7eca6ccb, + 0x7ec67be2, 0x7ec22eee, 0x7ebd7d1a, 0x7eb85c35, 0x7eb2c075, + 0x7eac9c20, 0x7ea5df27, 0x7e9e769f, 0x7e964c16, 0x7e8d44ba, + 0x7e834033, 0x7e781728, 0x7e6b9933, 0x7e5d8a1a, 0x7e4d9ded, + 0x7e3b737a, 0x7e268c2f, 0x7e0e3ff5, 0x7df1aa5d, 0x7dcf8c72, + 0x7da61a1e, 0x7d72a0fb, 0x7d30e097, 0x7cd9b4ab, 0x7c600f1a, + 0x7ba90bdc, 0x7a722176, 0x77d664e5, +} +var wn = [128]float32{ + 1.7290405e-09, 1.2680929e-10, 1.6897518e-10, 1.9862688e-10, + 2.2232431e-10, 2.4244937e-10, 2.601613e-10, 2.7611988e-10, + 2.9073963e-10, 3.042997e-10, 3.1699796e-10, 3.289802e-10, + 3.4035738e-10, 3.5121603e-10, 3.616251e-10, 3.7164058e-10, + 3.8130857e-10, 3.9066758e-10, 3.9975012e-10, 4.08584e-10, + 4.1719309e-10, 4.2559822e-10, 4.338176e-10, 4.418672e-10, + 4.497613e-10, 4.5751258e-10, 4.651324e-10, 4.7263105e-10, + 4.8001775e-10, 4.87301e-10, 4.944885e-10, 5.015873e-10, + 5.0860405e-10, 5.155446e-10, 5.2241467e-10, 5.2921934e-10, + 5.359635e-10, 5.426517e-10, 5.4928817e-10, 5.5587696e-10, + 5.624219e-10, 5.6892646e-10, 5.753941e-10, 5.818282e-10, + 5.882317e-10, 5.946077e-10, 6.00959e-10, 6.072884e-10, + 6.135985e-10, 6.19892e-10, 6.2617134e-10, 6.3243905e-10, + 6.386974e-10, 6.449488e-10, 6.511956e-10, 6.5744005e-10, + 6.6368433e-10, 6.699307e-10, 6.7618144e-10, 6.824387e-10, + 6.8870465e-10, 6.949815e-10, 7.012715e-10, 7.075768e-10, + 7.1389966e-10, 7.202424e-10, 7.266073e-10, 7.329966e-10, + 7.394128e-10, 7.4585826e-10, 7.5233547e-10, 7.58847e-10, + 7.653954e-10, 7.719835e-10, 7.7861395e-10, 7.852897e-10, + 7.920138e-10, 7.987892e-10, 8.0561924e-10, 8.125073e-10, + 8.194569e-10, 8.2647167e-10, 8.3355556e-10, 8.407127e-10, + 8.479473e-10, 8.55264e-10, 8.6266755e-10, 8.7016316e-10, + 8.777562e-10, 8.8545243e-10, 8.932582e-10, 9.0117996e-10, + 9.09225e-10, 9.174008e-10, 9.2571584e-10, 9.341788e-10, + 9.427997e-10, 9.515889e-10, 9.605579e-10, 9.697193e-10, + 9.790869e-10, 9.88676e-10, 9.985036e-10, 1.0085882e-09, + 1.0189509e-09, 1.0296151e-09, 1.0406069e-09, 1.0519566e-09, + 1.063698e-09, 1.0758702e-09, 1.0885183e-09, 1.1016947e-09, + 1.1154611e-09, 1.1298902e-09, 1.1450696e-09, 1.1611052e-09, + 1.1781276e-09, 1.1962995e-09, 1.2158287e-09, 1.2369856e-09, + 1.2601323e-09, 1.2857697e-09, 1.3146202e-09, 1.347784e-09, + 1.3870636e-09, 1.4357403e-09, 1.5008659e-09, 1.6030948e-09, +} +var fn = [128]float32{ + 1, 0.9635997, 0.9362827, 0.9130436, 0.89228165, 0.87324303, + 0.8555006, 0.8387836, 0.8229072, 0.8077383, 0.793177, + 0.7791461, 0.7655842, 0.7524416, 0.73967725, 0.7272569, + 0.7151515, 0.7033361, 0.69178915, 0.68049186, 0.6694277, + 0.658582, 0.6479418, 0.63749546, 0.6272325, 0.6171434, + 0.6072195, 0.5974532, 0.58783704, 0.5783647, 0.56903, + 0.5598274, 0.5507518, 0.54179835, 0.5329627, 0.52424055, + 0.5156282, 0.50712204, 0.49871865, 0.49041483, 0.48220766, + 0.4740943, 0.46607214, 0.4581387, 0.45029163, 0.44252872, + 0.43484783, 0.427247, 0.41972435, 0.41227803, 0.40490642, + 0.39760786, 0.3903808, 0.3832238, 0.37613547, 0.36911446, + 0.3621595, 0.35526937, 0.34844297, 0.34167916, 0.33497685, + 0.3283351, 0.3217529, 0.3152294, 0.30876362, 0.30235484, + 0.29600215, 0.28970486, 0.2834622, 0.2772735, 0.27113807, + 0.2650553, 0.25902456, 0.2530453, 0.24711695, 0.241239, + 0.23541094, 0.22963232, 0.2239027, 0.21822165, 0.21258877, + 0.20700371, 0.20146611, 0.19597565, 0.19053204, 0.18513499, + 0.17978427, 0.17447963, 0.1692209, 0.16400786, 0.15884037, + 0.15371831, 0.14864157, 0.14361008, 0.13862377, 0.13368265, + 0.12878671, 0.12393598, 0.119130544, 0.11437051, 0.10965602, + 0.104987256, 0.10036444, 0.095787846, 0.0912578, 0.08677467, + 0.0823389, 0.077950984, 0.073611505, 0.06932112, 0.06508058, + 0.06089077, 0.056752663, 0.0526674, 0.048636295, 0.044660863, + 0.040742867, 0.03688439, 0.033087887, 0.029356318, + 0.025693292, 0.022103304, 0.018592102, 0.015167298, + 0.011839478, 0.008624485, 0.005548995, 0.0026696292, +} diff --git a/vendor/golang.org/x/exp/rand/rand.go b/vendor/golang.org/x/exp/rand/rand.go new file mode 100644 index 000000000..ee6161bc6 --- /dev/null +++ b/vendor/golang.org/x/exp/rand/rand.go @@ -0,0 +1,372 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rand implements pseudo-random number generators. +// +// Random numbers are generated by a Source. Top-level functions, such as +// Float64 and Int, use a default shared Source that produces a deterministic +// sequence of values each time a program is run. Use the Seed function to +// initialize the default Source if different behavior is required for each run. +// The default Source, a LockedSource, is safe for concurrent use by multiple +// goroutines, but Sources created by NewSource are not. However, Sources are small +// and it is reasonable to have a separate Source for each goroutine, seeded +// differently, to avoid locking. +// +// For random numbers suitable for security-sensitive work, see the crypto/rand +// package. +package rand + +import "sync" + +// A Source represents a source of uniformly-distributed +// pseudo-random int64 values in the range [0, 1<<64). +type Source interface { + Uint64() uint64 + Seed(seed uint64) +} + +// NewSource returns a new pseudo-random Source seeded with the given value. +func NewSource(seed uint64) Source { + var rng PCGSource + rng.Seed(seed) + return &rng +} + +// A Rand is a source of random numbers. +type Rand struct { + src Source + + // readVal contains remainder of 64-bit integer used for bytes + // generation during most recent Read call. + // It is saved so next Read call can start where the previous + // one finished. + readVal uint64 + // readPos indicates the number of low-order bytes of readVal + // that are still valid. + readPos int8 +} + +// New returns a new Rand that uses random values from src +// to generate other random values. +func New(src Source) *Rand { + return &Rand{src: src} +} + +// Seed uses the provided seed value to initialize the generator to a deterministic state. +// Seed should not be called concurrently with any other Rand method. +func (r *Rand) Seed(seed uint64) { + if lk, ok := r.src.(*LockedSource); ok { + lk.seedPos(seed, &r.readPos) + return + } + + r.src.Seed(seed) + r.readPos = 0 +} + +// Uint64 returns a pseudo-random 64-bit integer as a uint64. +func (r *Rand) Uint64() uint64 { return r.src.Uint64() } + +// Int63 returns a non-negative pseudo-random 63-bit integer as an int64. +func (r *Rand) Int63() int64 { return int64(r.src.Uint64() &^ (1 << 63)) } + +// Uint32 returns a pseudo-random 32-bit value as a uint32. +func (r *Rand) Uint32() uint32 { return uint32(r.Uint64() >> 32) } + +// Int31 returns a non-negative pseudo-random 31-bit integer as an int32. +func (r *Rand) Int31() int32 { return int32(r.Uint64() >> 33) } + +// Int returns a non-negative pseudo-random int. +func (r *Rand) Int() int { + u := uint(r.Uint64()) + return int(u << 1 >> 1) // clear sign bit. +} + +const maxUint64 = (1 << 64) - 1 + +// Uint64n returns, as a uint64, a pseudo-random number in [0,n). +// It is guaranteed more uniform than taking a Source value mod n +// for any n that is not a power of 2. +func (r *Rand) Uint64n(n uint64) uint64 { + if n&(n-1) == 0 { // n is power of two, can mask + if n == 0 { + panic("invalid argument to Uint64n") + } + return r.Uint64() & (n - 1) + } + // If n does not divide v, to avoid bias we must not use + // a v that is within maxUint64%n of the top of the range. + v := r.Uint64() + if v > maxUint64-n { // Fast check. + ceiling := maxUint64 - maxUint64%n + for v >= ceiling { + v = r.Uint64() + } + } + + return v % n +} + +// Int63n returns, as an int64, a non-negative pseudo-random number in [0,n). +// It panics if n <= 0. +func (r *Rand) Int63n(n int64) int64 { + if n <= 0 { + panic("invalid argument to Int63n") + } + return int64(r.Uint64n(uint64(n))) +} + +// Int31n returns, as an int32, a non-negative pseudo-random number in [0,n). +// It panics if n <= 0. +func (r *Rand) Int31n(n int32) int32 { + if n <= 0 { + panic("invalid argument to Int31n") + } + // TODO: Avoid some 64-bit ops to make it more efficient on 32-bit machines. + return int32(r.Uint64n(uint64(n))) +} + +// Intn returns, as an int, a non-negative pseudo-random number in [0,n). +// It panics if n <= 0. +func (r *Rand) Intn(n int) int { + if n <= 0 { + panic("invalid argument to Intn") + } + // TODO: Avoid some 64-bit ops to make it more efficient on 32-bit machines. + return int(r.Uint64n(uint64(n))) +} + +// Float64 returns, as a float64, a pseudo-random number in [0.0,1.0). +func (r *Rand) Float64() float64 { + // There is one bug in the value stream: r.Int63() may be so close + // to 1<<63 that the division rounds up to 1.0, and we've guaranteed + // that the result is always less than 1.0. + // + // We tried to fix this by mapping 1.0 back to 0.0, but since float64 + // values near 0 are much denser than near 1, mapping 1 to 0 caused + // a theoretically significant overshoot in the probability of returning 0. + // Instead of that, if we round up to 1, just try again. + // Getting 1 only happens 1/2⁵³ of the time, so most clients + // will not observe it anyway. +again: + f := float64(r.Uint64n(1<<53)) / (1 << 53) + if f == 1.0 { + goto again // resample; this branch is taken O(never) + } + return f +} + +// Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). +func (r *Rand) Float32() float32 { + // We do not want to return 1.0. + // This only happens 1/2²⁴ of the time (plus the 1/2⁵³ of the time in Float64). +again: + f := float32(r.Float64()) + if f == 1 { + goto again // resample; this branch is taken O(very rarely) + } + return f +} + +// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n). +func (r *Rand) Perm(n int) []int { + m := make([]int, n) + // In the following loop, the iteration when i=0 always swaps m[0] with m[0]. + // A change to remove this useless iteration is to assign 1 to i in the init + // statement. But Perm also effects r. Making this change will affect + // the final state of r. So this change can't be made for compatibility + // reasons for Go 1. + for i := 0; i < n; i++ { + j := r.Intn(i + 1) + m[i] = m[j] + m[j] = i + } + return m +} + +// Shuffle pseudo-randomizes the order of elements. +// n is the number of elements. Shuffle panics if n < 0. +// swap swaps the elements with indexes i and j. +func (r *Rand) Shuffle(n int, swap func(i, j int)) { + if n < 0 { + panic("invalid argument to Shuffle") + } + + // Fisher-Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle + // Shuffle really ought not be called with n that doesn't fit in 32 bits. + // Not only will it take a very long time, but with 2³¹! possible permutations, + // there's no way that any PRNG can have a big enough internal state to + // generate even a minuscule percentage of the possible permutations. + // Nevertheless, the right API signature accepts an int n, so handle it as best we can. + i := n - 1 + for ; i > 1<<31-1-1; i-- { + j := int(r.Int63n(int64(i + 1))) + swap(i, j) + } + for ; i > 0; i-- { + j := int(r.Int31n(int32(i + 1))) + swap(i, j) + } +} + +// Read generates len(p) random bytes and writes them into p. It +// always returns len(p) and a nil error. +// Read should not be called concurrently with any other Rand method unless +// the underlying source is a LockedSource. +func (r *Rand) Read(p []byte) (n int, err error) { + if lk, ok := r.src.(*LockedSource); ok { + return lk.Read(p, &r.readVal, &r.readPos) + } + return read(p, r.src, &r.readVal, &r.readPos) +} + +func read(p []byte, src Source, readVal *uint64, readPos *int8) (n int, err error) { + pos := *readPos + val := *readVal + rng, _ := src.(*PCGSource) + for n = 0; n < len(p); n++ { + if pos == 0 { + if rng != nil { + val = rng.Uint64() + } else { + val = src.Uint64() + } + pos = 8 + } + p[n] = byte(val) + val >>= 8 + pos-- + } + *readPos = pos + *readVal = val + return +} + +/* + * Top-level convenience functions + */ + +var globalRand = New(&LockedSource{src: *NewSource(1).(*PCGSource)}) + +// Type assert that globalRand's source is a LockedSource whose src is a PCGSource. +var _ PCGSource = globalRand.src.(*LockedSource).src + +// Seed uses the provided seed value to initialize the default Source to a +// deterministic state. If Seed is not called, the generator behaves as +// if seeded by Seed(1). +// Seed, unlike the Rand.Seed method, is safe for concurrent use. +func Seed(seed uint64) { globalRand.Seed(seed) } + +// Int63 returns a non-negative pseudo-random 63-bit integer as an int64 +// from the default Source. +func Int63() int64 { return globalRand.Int63() } + +// Uint32 returns a pseudo-random 32-bit value as a uint32 +// from the default Source. +func Uint32() uint32 { return globalRand.Uint32() } + +// Uint64 returns a pseudo-random 64-bit value as a uint64 +// from the default Source. +func Uint64() uint64 { return globalRand.Uint64() } + +// Int31 returns a non-negative pseudo-random 31-bit integer as an int32 +// from the default Source. +func Int31() int32 { return globalRand.Int31() } + +// Int returns a non-negative pseudo-random int from the default Source. +func Int() int { return globalRand.Int() } + +// Int63n returns, as an int64, a non-negative pseudo-random number in [0,n) +// from the default Source. +// It panics if n <= 0. +func Int63n(n int64) int64 { return globalRand.Int63n(n) } + +// Int31n returns, as an int32, a non-negative pseudo-random number in [0,n) +// from the default Source. +// It panics if n <= 0. +func Int31n(n int32) int32 { return globalRand.Int31n(n) } + +// Intn returns, as an int, a non-negative pseudo-random number in [0,n) +// from the default Source. +// It panics if n <= 0. +func Intn(n int) int { return globalRand.Intn(n) } + +// Float64 returns, as a float64, a pseudo-random number in [0.0,1.0) +// from the default Source. +func Float64() float64 { return globalRand.Float64() } + +// Float32 returns, as a float32, a pseudo-random number in [0.0,1.0) +// from the default Source. +func Float32() float32 { return globalRand.Float32() } + +// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n) +// from the default Source. +func Perm(n int) []int { return globalRand.Perm(n) } + +// Shuffle pseudo-randomizes the order of elements using the default Source. +// n is the number of elements. Shuffle panics if n < 0. +// swap swaps the elements with indexes i and j. +func Shuffle(n int, swap func(i, j int)) { globalRand.Shuffle(n, swap) } + +// Read generates len(p) random bytes from the default Source and +// writes them into p. It always returns len(p) and a nil error. +// Read, unlike the Rand.Read method, is safe for concurrent use. +func Read(p []byte) (n int, err error) { return globalRand.Read(p) } + +// NormFloat64 returns a normally distributed float64 in the range +// [-math.MaxFloat64, +math.MaxFloat64] with +// standard normal distribution (mean = 0, stddev = 1) +// from the default Source. +// To produce a different normal distribution, callers can +// adjust the output using: +// +// sample = NormFloat64() * desiredStdDev + desiredMean +func NormFloat64() float64 { return globalRand.NormFloat64() } + +// ExpFloat64 returns an exponentially distributed float64 in the range +// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter +// (lambda) is 1 and whose mean is 1/lambda (1) from the default Source. +// To produce a distribution with a different rate parameter, +// callers can adjust the output using: +// +// sample = ExpFloat64() / desiredRateParameter +func ExpFloat64() float64 { return globalRand.ExpFloat64() } + +// LockedSource is an implementation of Source that is concurrency-safe. +// A Rand using a LockedSource is safe for concurrent use. +// +// The zero value of LockedSource is valid, but should be seeded before use. +type LockedSource struct { + lk sync.Mutex + src PCGSource +} + +func (s *LockedSource) Uint64() (n uint64) { + s.lk.Lock() + n = s.src.Uint64() + s.lk.Unlock() + return +} + +func (s *LockedSource) Seed(seed uint64) { + s.lk.Lock() + s.src.Seed(seed) + s.lk.Unlock() +} + +// seedPos implements Seed for a LockedSource without a race condiiton. +func (s *LockedSource) seedPos(seed uint64, readPos *int8) { + s.lk.Lock() + s.src.Seed(seed) + *readPos = 0 + s.lk.Unlock() +} + +// Read implements Read for a LockedSource. +func (s *LockedSource) Read(p []byte, readVal *uint64, readPos *int8) (n int, err error) { + s.lk.Lock() + n, err = read(p, &s.src, readVal, readPos) + s.lk.Unlock() + return +} diff --git a/vendor/golang.org/x/exp/rand/rng.go b/vendor/golang.org/x/exp/rand/rng.go new file mode 100644 index 000000000..9b79108c7 --- /dev/null +++ b/vendor/golang.org/x/exp/rand/rng.go @@ -0,0 +1,91 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rand + +import ( + "encoding/binary" + "io" + "math/bits" +) + +// PCGSource is an implementation of a 64-bit permuted congruential +// generator as defined in +// +// PCG: A Family of Simple Fast Space-Efficient Statistically Good +// Algorithms for Random Number Generation +// Melissa E. O’Neill, Harvey Mudd College +// http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf +// +// The generator here is the congruential generator PCG XSL RR 128/64 (LCG) +// as found in the software available at http://www.pcg-random.org/. +// It has period 2^128 with 128 bits of state, producing 64-bit values. +// Is state is represented by two uint64 words. +type PCGSource struct { + low uint64 + high uint64 +} + +const ( + maxUint32 = (1 << 32) - 1 + + multiplier = 47026247687942121848144207491837523525 + mulHigh = multiplier >> 64 + mulLow = multiplier & maxUint64 + + increment = 117397592171526113268558934119004209487 + incHigh = increment >> 64 + incLow = increment & maxUint64 + + // TODO: Use these? + initializer = 245720598905631564143578724636268694099 + initHigh = initializer >> 64 + initLow = initializer & maxUint64 +) + +// Seed uses the provided seed value to initialize the generator to a deterministic state. +func (pcg *PCGSource) Seed(seed uint64) { + pcg.low = seed + pcg.high = seed // TODO: What is right? +} + +// Uint64 returns a pseudo-random 64-bit unsigned integer as a uint64. +func (pcg *PCGSource) Uint64() uint64 { + pcg.multiply() + pcg.add() + // XOR high and low 64 bits together and rotate right by high 6 bits of state. + return bits.RotateLeft64(pcg.high^pcg.low, -int(pcg.high>>58)) +} + +func (pcg *PCGSource) add() { + var carry uint64 + pcg.low, carry = bits.Add64(pcg.low, incLow, 0) + pcg.high, _ = bits.Add64(pcg.high, incHigh, carry) +} + +func (pcg *PCGSource) multiply() { + hi, lo := bits.Mul64(pcg.low, mulLow) + hi += pcg.high * mulLow + hi += pcg.low * mulHigh + pcg.low = lo + pcg.high = hi +} + +// MarshalBinary returns the binary representation of the current state of the generator. +func (pcg *PCGSource) MarshalBinary() ([]byte, error) { + var buf [16]byte + binary.BigEndian.PutUint64(buf[:8], pcg.high) + binary.BigEndian.PutUint64(buf[8:], pcg.low) + return buf[:], nil +} + +// UnmarshalBinary sets the state of the generator to the state represented in data. +func (pcg *PCGSource) UnmarshalBinary(data []byte) error { + if len(data) < 16 { + return io.ErrUnexpectedEOF + } + pcg.low = binary.BigEndian.Uint64(data[8:]) + pcg.high = binary.BigEndian.Uint64(data[:8]) + return nil +} diff --git a/vendor/golang.org/x/exp/rand/zipf.go b/vendor/golang.org/x/exp/rand/zipf.go new file mode 100644 index 000000000..f04c814eb --- /dev/null +++ b/vendor/golang.org/x/exp/rand/zipf.go @@ -0,0 +1,77 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// W.Hormann, G.Derflinger: +// "Rejection-Inversion to Generate Variates +// from Monotone Discrete Distributions" +// http://eeyore.wu-wien.ac.at/papers/96-04-04.wh-der.ps.gz + +package rand + +import "math" + +// A Zipf generates Zipf distributed variates. +type Zipf struct { + r *Rand + imax float64 + v float64 + q float64 + s float64 + oneminusQ float64 + oneminusQinv float64 + hxm float64 + hx0minusHxm float64 +} + +func (z *Zipf) h(x float64) float64 { + return math.Exp(z.oneminusQ*math.Log(z.v+x)) * z.oneminusQinv +} + +func (z *Zipf) hinv(x float64) float64 { + return math.Exp(z.oneminusQinv*math.Log(z.oneminusQ*x)) - z.v +} + +// NewZipf returns a Zipf variate generator. +// The generator generates values k ∈ [0, imax] +// such that P(k) is proportional to (v + k) ** (-s). +// Requirements: s > 1 and v >= 1. +func NewZipf(r *Rand, s float64, v float64, imax uint64) *Zipf { + z := new(Zipf) + if s <= 1.0 || v < 1 { + return nil + } + z.r = r + z.imax = float64(imax) + z.v = v + z.q = s + z.oneminusQ = 1.0 - z.q + z.oneminusQinv = 1.0 / z.oneminusQ + z.hxm = z.h(z.imax + 0.5) + z.hx0minusHxm = z.h(0.5) - math.Exp(math.Log(z.v)*(-z.q)) - z.hxm + z.s = 1 - z.hinv(z.h(1.5)-math.Exp(-z.q*math.Log(z.v+1.0))) + return z +} + +// Uint64 returns a value drawn from the Zipf distribution described +// by the Zipf object. +func (z *Zipf) Uint64() uint64 { + if z == nil { + panic("rand: nil Zipf") + } + k := 0.0 + + for { + r := z.r.Float64() // r on [0,1] + ur := z.hxm + r*z.hx0minusHxm + x := z.hinv(ur) + k = math.Floor(x + 0.5) + if k-x <= z.s { + break + } + if ur >= z.h(k+0.5)-math.Exp(-math.Log(k+z.v)*z.q) { + break + } + } + return uint64(k) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index f66ff5b94..d062e1f5d 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -216,10 +216,6 @@ github.com/golang/glog # github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da ## explicit github.com/golang/groupcache/lru -# github.com/golang/mock v1.6.0 -## explicit; go 1.11 -github.com/golang/mock/mockgen -github.com/golang/mock/mockgen/model # github.com/golang/protobuf v1.5.3 ## explicit; go 1.9 github.com/golang/protobuf/proto @@ -418,11 +414,8 @@ github.com/pmezard/go-difflib/difflib github.com/pterm/pterm github.com/pterm/pterm/internal github.com/pterm/pterm/putils -# github.com/quic-go/qtls-go1-20 v0.3.3 -## explicit; go 1.20 -github.com/quic-go/qtls-go1-20 -# github.com/quic-go/quic-go v0.38.1 -## explicit; go 1.20 +# github.com/quic-go/quic-go v0.42.0 +## explicit; go 1.21 github.com/quic-go/quic-go github.com/quic-go/quic-go/internal/ackhandler github.com/quic-go/quic-go/internal/congestion @@ -643,6 +636,10 @@ go.opencensus.io/internal go.opencensus.io/trace go.opencensus.io/trace/internal go.opencensus.io/trace/tracestate +# go.uber.org/mock v0.4.0 +## explicit; go 1.20 +go.uber.org/mock/mockgen +go.uber.org/mock/mockgen/model # golang.org/x/arch v0.4.0 ## explicit; go 1.17 golang.org/x/arch/x86/x86asm @@ -654,8 +651,6 @@ golang.org/x/crypto/blowfish golang.org/x/crypto/cast5 golang.org/x/crypto/chacha20 golang.org/x/crypto/chacha20poly1305 -golang.org/x/crypto/cryptobyte -golang.org/x/crypto/cryptobyte/asn1 golang.org/x/crypto/curve25519 golang.org/x/crypto/curve25519/internal/field golang.org/x/crypto/hkdf @@ -671,7 +666,7 @@ golang.org/x/crypto/twofish golang.org/x/crypto/xtea # golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 ## explicit; go 1.20 -golang.org/x/exp/constraints +golang.org/x/exp/rand # golang.org/x/mod v0.12.0 ## explicit; go 1.17 golang.org/x/mod/internal/lazyregexp