diff --git a/.gitignore b/.gitignore index b454729d5b2..ce05268411c 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,96 @@ fuzzing/*/suppressions fuzzing/*/corpus/ gomock_reflect_*/ + +# Created by https://www.toptal.com/developers/gitignore/api/goland+all +# Edit at https://www.toptal.com/developers/gitignore?templates=goland+all + +### GoLand+all ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### GoLand+all Patch ### +# Ignore everything but code style settings and run configurations +# that are supposed to be shared within teams. + +.idea/* + +!.idea/codeStyles +!.idea/runConfigurations + +# End of https://www.toptal.com/developers/gitignore/api/goland+all \ No newline at end of file diff --git a/congestion/interface.go b/congestion/interface.go new file mode 100644 index 00000000000..3e41a4de692 --- /dev/null +++ b/congestion/interface.go @@ -0,0 +1,65 @@ +package congestion + +import ( + "time" + + "github.com/quic-go/quic-go/internal/protocol" +) + +type ( + ByteCount protocol.ByteCount + PacketNumber protocol.PacketNumber +) + +// Expose some constants from protocol that congestion control algorithms may need. 
+const ( + InitialPacketSizeIPv4 = protocol.InitialPacketSize + InitialPacketSizeIPv6 = protocol.InitialPacketSize + MinPacingDelay = protocol.MinPacingDelay + MaxPacketBufferSize = protocol.MaxPacketBufferSize + MinInitialPacketSize = protocol.MinInitialPacketSize + MaxCongestionWindowPackets = protocol.MaxCongestionWindowPackets + PacketsPerConnectionID = protocol.PacketsPerConnectionID +) + +type AckedPacketInfo struct { + PacketNumber PacketNumber + BytesAcked ByteCount + ReceivedTime time.Time +} + +type LostPacketInfo struct { + PacketNumber PacketNumber + BytesLost ByteCount +} + +type CongestionControl interface { + SetRTTStatsProvider(provider RTTStatsProvider) + TimeUntilSend(bytesInFlight ByteCount) time.Time + HasPacingBudget(now time.Time) bool + OnPacketSent(sentTime time.Time, bytesInFlight ByteCount, packetNumber PacketNumber, bytes ByteCount, isRetransmittable bool) + CanSend(bytesInFlight ByteCount) bool + MaybeExitSlowStart() + OnPacketAcked(number PacketNumber, ackedBytes ByteCount, priorInFlight ByteCount, eventTime time.Time) + OnCongestionEvent(number PacketNumber, lostBytes ByteCount, priorInFlight ByteCount) + OnCongestionEventEx(priorInFlight ByteCount, eventTime time.Time, ackedPackets []AckedPacketInfo, lostPackets []LostPacketInfo) + OnRetransmissionTimeout(packetsRetransmitted bool) + SetMaxDatagramSize(size ByteCount) + InSlowStart() bool + InRecovery() bool + GetCongestionWindow() ByteCount +} + +type RTTStatsProvider interface { + MinRTT() time.Duration + LatestRTT() time.Duration + SmoothedRTT() time.Duration + MeanDeviation() time.Duration + MaxAckDelay() time.Duration + PTO(includeMaxAckDelay bool) time.Duration + UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) + SetMaxAckDelay(mad time.Duration) + SetInitialRTT(t time.Duration) + OnConnectionMigration() + ExpireSmoothedMetrics() +} diff --git a/connection.go b/connection.go index 1411a77b739..cd10fd7ce99 100644 --- a/connection.go +++ b/connection.go @@ -13,6 +13,7 @@ import ( "sync/atomic" "time" + "github.com/quic-go/quic-go/congestion" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/flowcontrol" "github.com/quic-go/quic-go/internal/handshake" @@ -871,6 +872,13 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { } } + // Hysteria connection migration + // Set remote address to the address of the last received valid packet + if s.perspective == protocol.PerspectiveServer && processed { + // Connection migration + s.conn.SetRemoteAddr(rp.remoteAddr) + } + p.buffer.MaybeRelease() return processed } @@ -2286,7 +2294,9 @@ func (s *connection) SendDatagram(p []byte) error { protocol.ByteCount(s.maxPayloadSizeEstimate.Load()), ) if protocol.ByteCount(len(p)) > maxDataLen { - return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)} + return &DatagramTooLargeError{ + MaxDataLen: int64(maxDataLen), + } } f.Data = make([]byte, len(p)) copy(f.Data, p) @@ -2331,3 +2341,7 @@ func (s *connection) NextConnection(ctx context.Context) (Connection, error) { func estimateMaxPayloadSize(mtu protocol.ByteCount) protocol.ByteCount { return mtu - 1 /* type byte */ - 20 /* maximum connection ID length */ - 16 /* tag size */ } + +func (s *connection) SetCongestionControl(cc congestion.CongestionControl) { + s.sentPacketHandler.SetCongestionControl(cc) +} diff --git a/connection_test.go b/connection_test.go index 6ece26dc7bb..af61a14c810 100644 --- a/connection_test.go +++ b/connection_test.go @@ -2506,8 +2506,8 @@ var _ = Describe("Connection", 
func() { Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&DatagramTooLargeError{})) derr := err.(*DatagramTooLargeError) - Expect(derr.MaxDatagramPayloadSize).To(BeNumerically("<", 1000)) - Expect(conn.SendDatagram(make([]byte, derr.MaxDatagramPayloadSize))).To(Succeed()) + Expect(derr.MaxDataLen).To(BeNumerically("<", 1000)) + Expect(conn.SendDatagram(make([]byte, derr.MaxDataLen))).To(Succeed()) }) It("receives datagrams", func() { diff --git a/errors.go b/errors.go index 3fe1e0a9024..8275246f13b 100644 --- a/errors.go +++ b/errors.go @@ -64,7 +64,7 @@ func (e *StreamError) Error() string { // DatagramTooLargeError is returned from Connection.SendDatagram if the payload is too large to be sent. type DatagramTooLargeError struct { - MaxDatagramPayloadSize int64 + MaxDataLen int64 } func (e *DatagramTooLargeError) Is(target error) bool { diff --git a/frame_sorter.go b/frame_sorter.go index bee0abadb53..1637a3ee49f 100644 --- a/frame_sorter.go +++ b/frame_sorter.go @@ -2,10 +2,10 @@ package quic import ( "errors" - "sync" "github.com/quic-go/quic-go/internal/protocol" - list "github.com/quic-go/quic-go/internal/utils/linkedlist" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/utils/tree" ) // byteInterval is an interval from one ByteCount to the other @@ -14,12 +14,6 @@ type byteInterval struct { End protocol.ByteCount } -var byteIntervalElementPool sync.Pool - -func init() { - byteIntervalElementPool = *list.NewPool[byteInterval]() -} - type frameSorterEntry struct { Data []byte DoneCb func() @@ -28,17 +22,17 @@ type frameSorterEntry struct { type frameSorter struct { queue map[protocol.ByteCount]frameSorterEntry readPos protocol.ByteCount - gaps *list.List[byteInterval] + gapTree *tree.Btree[utils.ByteInterval] } var errDuplicateStreamData = errors.New("duplicate stream data") func newFrameSorter() *frameSorter { s := frameSorter{ - gaps: list.NewWithPool[byteInterval](&byteIntervalElementPool), - queue: make(map[protocol.ByteCount]frameSorterEntry), + gapTree: tree.New[utils.ByteInterval](), + queue: make(map[protocol.ByteCount]frameSorterEntry), } - s.gaps.PushFront(byteInterval{Start: 0, End: protocol.MaxByteCount}) + s.gapTree.Insert(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount}) return &s } @@ -60,25 +54,30 @@ func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func() start := offset end := offset + protocol.ByteCount(len(data)) + covInterval := utils.ByteInterval{Start: start, End: end} - if end <= s.gaps.Front().Value.Start { + gaps := s.gapTree.Match(covInterval) + + if len(gaps) == 0 { + // No overlap with any existing gap return errDuplicateStreamData } - startGap, startsInGap := s.findStartGap(start) - endGap, endsInGap := s.findEndGap(startGap, end) - - startGapEqualsEndGap := startGap == endGap + startGap := gaps[0] + endGap := gaps[len(gaps)-1] + startGapEqualsEndGap := len(gaps) == 1 - if (startGapEqualsEndGap && end <= startGap.Value.Start) || - (!startGapEqualsEndGap && startGap.Value.End >= endGap.Value.Start && end <= startGap.Value.Start) { + if startGapEqualsEndGap && end <= startGap.Start { return errDuplicateStreamData } - startGapNext := startGap.Next() - startGapEnd := startGap.Value.End // save it, in case startGap is modified - endGapStart := endGap.Value.Start // save it, in case endGap is modified - endGapEnd := endGap.Value.End // save it, in case endGap is modified + startsInGap := covInterval.Start >= startGap.Start && covInterval.Start <= startGap.End + endsInGap := 
covInterval.End >= endGap.Start && covInterval.End < endGap.End + + startGapEnd := startGap.End // save it, in case startGap is modified + endGapStart := endGap.Start // save it, in case endGap is modified + endGapEnd := endGap.End // save it, in case endGap is modified + var adjustedStartGapEnd bool var wasCut bool @@ -113,29 +112,36 @@ func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func() if !startsInGap && !hasReplacedAtLeastOne { // cut the frame, such that it starts at the start of the gap - data = data[startGap.Value.Start-start:] - start = startGap.Value.Start + data = data[startGap.Start-start:] + start = startGap.Start wasCut = true } - if start <= startGap.Value.Start { - if end >= startGap.Value.End { + if start <= startGap.Start { + if end >= startGap.End { // The frame covers the whole startGap. Delete the gap. - s.gaps.Remove(startGap) + s.gapTree.Delete(startGap) } else { - startGap.Value.Start = end + s.gapTree.Delete(startGap) + startGap.Start = end + // Re-insert the gap, but with the new start. + s.gapTree.Insert(startGap) } } else if !hasReplacedAtLeastOne { - startGap.Value.End = start + s.gapTree.Delete(startGap) + startGap.End = start + // Re-insert the gap, but with the new end. + s.gapTree.Insert(startGap) adjustedStartGapEnd = true } if !startGapEqualsEndGap { s.deleteConsecutive(startGapEnd) - var nextGap *list.Element[byteInterval] - for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap { - nextGap = gap.Next() - s.deleteConsecutive(gap.Value.End) - s.gaps.Remove(gap) + for _, g := range gaps[1:] { + if g.End >= endGapStart { + break + } + s.deleteConsecutive(g.End) + s.gapTree.Delete(g) } } @@ -148,14 +154,17 @@ func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func() if end == endGapEnd { if !startGapEqualsEndGap { // The frame covers the whole endGap. Delete the gap. - s.gaps.Remove(endGap) + s.gapTree.Delete(endGap) } } else { if startGapEqualsEndGap && adjustedStartGapEnd { // The frame split the existing gap into two. - s.gaps.InsertAfter(byteInterval{Start: end, End: startGapEnd}, startGap) + s.gapTree.Insert(utils.ByteInterval{Start: end, End: startGapEnd}) } else if !startGapEqualsEndGap { - endGap.Value.Start = end + s.gapTree.Delete(endGap) + endGap.Start = end + // Re-insert the gap, but with the new start. 
+ s.gapTree.Insert(endGap) } } @@ -169,7 +178,7 @@ func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func() } } - if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps { + if s.gapTree.Len() > protocol.MaxStreamFrameSorterGaps { return errors.New("too many gaps in received data") } @@ -177,30 +186,6 @@ func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func() return nil } -func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*list.Element[byteInterval], bool) { - for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { - if offset >= gap.Value.Start && offset <= gap.Value.End { - return gap, true - } - if offset < gap.Value.Start { - return gap, false - } - } - panic("no gap found") -} - -func (s *frameSorter) findEndGap(startGap *list.Element[byteInterval], offset protocol.ByteCount) (*list.Element[byteInterval], bool) { - for gap := startGap; gap != nil; gap = gap.Next() { - if offset >= gap.Value.Start && offset < gap.Value.End { - return gap, true - } - if offset < gap.Value.Start { - return gap.Prev(), false - } - } - panic("no gap found") -} - // deleteConsecutive deletes consecutive frames from the queue, starting at pos func (s *frameSorter) deleteConsecutive(pos protocol.ByteCount) { for { @@ -225,9 +210,6 @@ func (s *frameSorter) Pop() (protocol.ByteCount, []byte, func()) { delete(s.queue, s.readPos) offset := s.readPos s.readPos += protocol.ByteCount(len(entry.Data)) - if s.gaps.Front().Value.End <= s.readPos { - panic("frame sorter BUG: read position higher than a gap") - } return offset, entry.Data, entry.DoneCb } diff --git a/frame_sorter_test.go b/frame_sorter_test.go index 9a684c91c1e..46fb7078277 100644 --- a/frame_sorter_test.go +++ b/frame_sorter_test.go @@ -9,6 +9,8 @@ import ( "golang.org/x/exp/rand" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/utils/tree" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" @@ -18,18 +20,25 @@ var _ = Describe("frame sorter", func() { var s *frameSorter checkGaps := func(expectedGaps []byteInterval) { - if s.gaps.Len() != len(expectedGaps) { + if s.gapTree.Len() != len(expectedGaps) { fmt.Println("Gaps:") - for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { - fmt.Printf("\t%d - %d\n", gap.Value.Start, gap.Value.End) - } - ExpectWithOffset(1, s.gaps.Len()).To(Equal(len(expectedGaps))) + s.gapTree.Ascend(func(n *tree.Node[utils.ByteInterval], i int) bool { + gap := n.Value + fmt.Printf("\t%d - %d\n", gap.Start, gap.End) + return true + }) + ExpectWithOffset(1, s.gapTree.Len()).To(Equal(len(expectedGaps))) } var i int - for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { - ExpectWithOffset(1, gap.Value).To(Equal(expectedGaps[i])) + s.gapTree.Ascend(func(n *tree.Node[utils.ByteInterval], _ int) bool { + gap := n.Value + ExpectWithOffset(1, gap).To(Equal(utils.ByteInterval{ + Start: expectedGaps[i].Start, + End: expectedGaps[i].End, + })) i++ - } + return true + }) } type callbackTracker struct { @@ -1378,7 +1387,7 @@ var _ = Describe("frame sorter", func() { for i := 0; i < protocol.MaxStreamFrameSorterGaps; i++ { Expect(s.Push([]byte("foobar"), protocol.ByteCount(i*7), nil)).To(Succeed()) } - Expect(s.gaps.Len()).To(Equal(protocol.MaxStreamFrameSorterGaps)) + Expect(s.gapTree.Len()).To(Equal(protocol.MaxStreamFrameSorterGaps)) err := s.Push([]byte("foobar"), protocol.ByteCount(protocol.MaxStreamFrameSorterGaps*7)+100, nil) Expect(err).To(MatchError("too many gaps in received data")) }) diff --git a/go.mod b/go.mod index 17db49c19af..f1697b797f0 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22 require ( github.com/francoispqt/gojay v1.2.13 + github.com/golang/mock v1.2.0 github.com/onsi/ginkgo/v2 v2.9.5 github.com/onsi/gomega v1.27.6 github.com/prometheus/client_golang v1.19.1 diff --git a/go.sum b/go.sum index 25fe9d114fe..e03172af58a 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,7 @@ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index 1b4f0efdeaa..3dc8cb56fed 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -64,7 +64,7 @@ var _ = Describe("Datagram test", func() { maxDatagramMessageSize := f.MaxDataLen(maxDatagramSize, conn.ConnectionState().Version) b := make([]byte, maxDatagramMessageSize+1) Expect(conn.SendDatagram(b)).To(MatchError(&quic.DatagramTooLargeError{ - MaxDatagramPayloadSize: int64(maxDatagramMessageSize), + MaxDataLen: int64(maxDatagramMessageSize), })) wg.Wait() } diff --git a/integrationtests/self/mtu_test.go b/integrationtests/self/mtu_test.go index 744d9fa047b..5c8d1c3b17c 100644 --- a/integrationtests/self/mtu_test.go +++ b/integrationtests/self/mtu_test.go @@ -110,14 +110,14 @@ 
var _ = Describe("DPLPMTUD", func() { }() err = conn.SendDatagram(make([]byte, 2000)) Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{})) - initialMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize + initialMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDataLen _, err = str.Write(PRDataLong) Expect(err).ToNot(HaveOccurred()) str.Close() Eventually(done, 20*time.Second).Should(BeClosed()) err = conn.SendDatagram(make([]byte, 2000)) Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{})) - finalMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize + finalMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDataLen mx.Lock() defer mx.Unlock() diff --git a/interface.go b/interface.go index 2071b596f74..2ab5f4f97bd 100644 --- a/interface.go +++ b/interface.go @@ -8,6 +8,7 @@ import ( "net" "time" + "github.com/quic-go/quic-go/congestion" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/logging" @@ -206,6 +207,9 @@ type Connection interface { SendDatagram(payload []byte) error // ReceiveDatagram gets a message received in a datagram, as specified in RFC 9221. ReceiveDatagram(context.Context) ([]byte, error) + + // Replace the current congestion control algorithm with a new one. + SetCongestionControl(congestion.CongestionControl) } // An EarlyConnection is a connection that is handshaking. diff --git a/internal/ackhandler/cc_adapter.go b/internal/ackhandler/cc_adapter.go new file mode 100644 index 00000000000..659ab83495a --- /dev/null +++ b/internal/ackhandler/cc_adapter.go @@ -0,0 +1,70 @@ +package ackhandler + +import ( + "time" + + "github.com/quic-go/quic-go/congestion" + cgInternal "github.com/quic-go/quic-go/internal/congestion" + "github.com/quic-go/quic-go/internal/protocol" +) + +var ( + _ cgInternal.SendAlgorithmEx = &ccAdapter{} + _ cgInternal.SendAlgorithmWithDebugInfos = &ccAdapter{} +) + +type ccAdapter struct { + CC congestion.CongestionControl +} + +func (a *ccAdapter) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time { + return a.CC.TimeUntilSend(congestion.ByteCount(bytesInFlight)) +} + +func (a *ccAdapter) HasPacingBudget(now time.Time) bool { + return a.CC.HasPacingBudget(now) +} + +func (a *ccAdapter) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) { + a.CC.OnPacketSent(sentTime, congestion.ByteCount(bytesInFlight), congestion.PacketNumber(packetNumber), congestion.ByteCount(bytes), isRetransmittable) +} + +func (a *ccAdapter) CanSend(bytesInFlight protocol.ByteCount) bool { + return a.CC.CanSend(congestion.ByteCount(bytesInFlight)) +} + +func (a *ccAdapter) MaybeExitSlowStart() { + a.CC.MaybeExitSlowStart() +} + +func (a *ccAdapter) OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time) { + a.CC.OnPacketAcked(congestion.PacketNumber(number), congestion.ByteCount(ackedBytes), congestion.ByteCount(priorInFlight), eventTime) +} + +func (a *ccAdapter) OnCongestionEvent(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) { + a.CC.OnCongestionEvent(congestion.PacketNumber(number), congestion.ByteCount(lostBytes), congestion.ByteCount(priorInFlight)) +} + +func (a *ccAdapter) OnCongestionEventEx(priorInFlight protocol.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, 
lostPackets []congestion.LostPacketInfo) { + a.CC.OnCongestionEventEx(congestion.ByteCount(priorInFlight), eventTime, ackedPackets, lostPackets) +} + +func (a *ccAdapter) OnRetransmissionTimeout(packetsRetransmitted bool) { + a.CC.OnRetransmissionTimeout(packetsRetransmitted) +} + +func (a *ccAdapter) SetMaxDatagramSize(size protocol.ByteCount) { + a.CC.SetMaxDatagramSize(congestion.ByteCount(size)) +} + +func (a *ccAdapter) InSlowStart() bool { + return a.CC.InSlowStart() +} + +func (a *ccAdapter) InRecovery() bool { + return a.CC.InRecovery() +} + +func (a *ccAdapter) GetCongestionWindow() protocol.ByteCount { + return protocol.ByteCount(a.CC.GetCongestionWindow()) +} diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index ba8cbbdae02..ffbb41514ce 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -3,6 +3,7 @@ package ackhandler import ( "time" + "github.com/quic-go/quic-go/congestion" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) @@ -35,6 +36,8 @@ type SentPacketHandler interface { GetLossDetectionTimeout() time.Time OnLossDetectionTimeout() error + + SetCongestionControl(congestion.CongestionControl) } type sentPacketTracker interface { diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index b84f0dcbbc8..8204965c7c5 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -3,8 +3,10 @@ package ackhandler import ( "errors" "fmt" + "sync" "time" + congestionExt "github.com/quic-go/quic-go/congestion" "github.com/quic-go/quic-go/internal/congestion" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" @@ -75,12 +77,15 @@ type sentPacketHandler struct { // Only applies to the application-data packet number space. lowestNotConfirmedAcked protocol.PacketNumber - ackedPackets []*packet // to avoid allocations in detectAndRemoveAckedPackets + ackedPackets []*packet // to avoid allocations in detectAndRemoveAckedPackets + ackedPacketsInfo []congestionExt.AckedPacketInfo + lostPacketsInfo []congestionExt.LostPacketInfo bytesInFlight protocol.ByteCount - congestion congestion.SendAlgorithmWithDebugInfos - rttStats *utils.RTTStats + congestion congestion.SendAlgorithmWithDebugInfos + congestionMutex sync.RWMutex + rttStats *utils.RTTStats // The number of times a PTO has been sent without receiving an ack. 
ptoCount uint32 @@ -260,7 +265,9 @@ func (h *sentPacketHandler) SentPacket( h.numProbesToSend-- } } - h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting) + + cc := h.getCongestionControl() + cc.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting) if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil { h.ecnTracker.SentPacket(pn, ecn) @@ -287,7 +294,7 @@ func (h *sentPacketHandler) SentPacket( pnSpace.history.SentAckElicitingPacket(p) if h.tracer != nil && h.tracer.UpdatedMetrics != nil { - h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) + h.tracer.UpdatedMetrics(h.rttStats, cc.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } h.setLossDetectionTimer() } @@ -330,6 +337,9 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En if err != nil || len(ackedPackets) == 0 { return false, err } + + cc := h.getCongestionControl() + // update the RTT, if the largest acked is newly acknowledged if len(ackedPackets) > 0 { if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() { @@ -342,7 +352,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En if h.logger.Debug() { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) } - h.congestion.MaybeExitSlowStart() + cc.MaybeExitSlowStart() } } @@ -350,7 +360,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked { congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE)) if congested { - h.congestion.OnCongestionEvent(largestAcked, 0, priorInFlight) + cc.OnCongestionEvent(largestAcked, 0, priorInFlight) } } @@ -359,10 +369,15 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En if err := h.detectLostPackets(rcvTime, encLevel); err != nil { return false, err } + h.ackedPacketsInfo = h.ackedPacketsInfo[:0] var acked1RTTPacket bool for _, p := range ackedPackets { if p.includedInBytesInFlight && !p.declaredLost { - h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) + cc.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime) + h.ackedPacketsInfo = append(h.ackedPacketsInfo, congestionExt.AckedPacketInfo{ + PacketNumber: congestionExt.PacketNumber(p.PacketNumber), + BytesAcked: congestionExt.ByteCount(p.Length), + }) } if p.EncryptionLevel == protocol.Encryption1RTT { acked1RTTPacket = true @@ -370,9 +385,17 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.removeFromBytesInFlight(p) putPacket(p) } + + if cex, ok := h.congestion.(congestion.SendAlgorithmEx); ok && + (len(h.ackedPacketsInfo) != 0 || len(h.lostPacketsInfo) != 0) { + cex.OnCongestionEventEx(priorInFlight, rcvTime, h.ackedPacketsInfo, h.lostPacketsInfo) + } + // After this point, we must not use ackedPackets any longer! // We've already returned the buffers. - ackedPackets = nil //nolint:ineffassign // This is just to be on the safe side. + ackedPackets = nil //nolint:ineffassign // This is just to be on the safe side. + h.ackedPacketsInfo = nil //nolint:ineffassign // This is just to be on the safe side. + h.lostPacketsInfo = nil //nolint:ineffassign // This is just to be on the safe side. 
// Reset the pto_count unless the client is unsure if the server has validated the client's address. if h.peerCompletedAddressValidation { @@ -384,7 +407,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.numProbesToSend = 0 if h.tracer != nil && h.tracer.UpdatedMetrics != nil { - h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) + h.tracer.UpdatedMetrics(h.rttStats, cc.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } h.setLossDetectionTimer() @@ -604,6 +627,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { } func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error { + h.lostPacketsInfo = h.lostPacketsInfo[:0] pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.lossTime = time.Time{} @@ -616,8 +640,10 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E // Packets sent before this time are deemed lost. lostSendTime := now.Add(-lossDelay) + cc := h.getCongestionControl() + priorInFlight := h.bytesInFlight - return pnSpace.history.Iterate(func(p *packet) (bool, error) { + err := pnSpace.history.Iterate(func(p *packet) (bool, error) { if p.PacketNumber > pnSpace.largestAcked { return false, nil } @@ -658,8 +684,12 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E h.removeFromBytesInFlight(p) h.queueFramesForRetransmission(p) if !p.IsPathMTUProbePacket { - h.congestion.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight) + cc.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight) } + h.lostPacketsInfo = append(h.lostPacketsInfo, congestionExt.LostPacketInfo{ + PacketNumber: congestionExt.PacketNumber(p.PacketNumber), + BytesLost: congestionExt.ByteCount(p.Length), + }) if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil { h.ecnTracker.LostPacket(p.PacketNumber) } @@ -667,10 +697,12 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E } return true, nil }) + return err } func (h *sentPacketHandler) OnLossDetectionTimeout() error { defer h.setLossDetectionTimer() + priorInFlight := h.bytesInFlight earliestLossTime, encLevel := h.getLossTimeAndSpace() if !earliestLossTime.IsZero() { if h.logger.Debug() { @@ -680,7 +712,14 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) } // Early retransmit or time loss detection - return h.detectLostPackets(time.Now(), encLevel) + err := h.detectLostPackets(time.Now(), encLevel) + + if cex, ok := h.congestion.(congestion.SendAlgorithmEx); ok && + len(h.lostPacketsInfo) != 0 { + cex.OnCongestionEventEx(priorInFlight, time.Now(), nil, h.lostPacketsInfo) + } + + return err } // PTO @@ -799,9 +838,10 @@ func (h *sentPacketHandler) SendMode(now time.Time) SendMode { return h.ptoMode } // Only send ACKs if we're congestion limited. 
- if !h.congestion.CanSend(h.bytesInFlight) { + cc := h.getCongestionControl() + if !cc.CanSend(h.bytesInFlight) { if h.logger.Debug() { - h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, h.congestion.GetCongestionWindow()) + h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cc.GetCongestionWindow()) } return SendAck } @@ -811,18 +851,18 @@ func (h *sentPacketHandler) SendMode(now time.Time) SendMode { } return SendAck } - if !h.congestion.HasPacingBudget(now) { + if !cc.HasPacingBudget(now) { return SendPacingLimited } return SendAny } func (h *sentPacketHandler) TimeUntilSend() time.Time { - return h.congestion.TimeUntilSend(h.bytesInFlight) + return h.getCongestionControl().TimeUntilSend(h.bytesInFlight) } func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) { - h.congestion.SetMaxDatagramSize(s) + h.getCongestionControl().SetMaxDatagramSize(s) } func (h *sentPacketHandler) isAmplificationLimited() bool { @@ -895,7 +935,8 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) } if h.tracer != nil && h.tracer.UpdatedMetrics != nil { - h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) + cc := h.getCongestionControl() + h.tracer.UpdatedMetrics(h.rttStats, cc.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } } h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false) @@ -926,3 +967,17 @@ func (h *sentPacketHandler) SetHandshakeConfirmed() { // Make sure the timer is armed now, if necessary. h.setLossDetectionTimer() } + +func (h *sentPacketHandler) getCongestionControl() congestion.SendAlgorithmWithDebugInfos { + h.congestionMutex.RLock() + cc := h.congestion + h.congestionMutex.RUnlock() + return cc +} + +func (h *sentPacketHandler) SetCongestionControl(cc congestionExt.CongestionControl) { + h.congestionMutex.Lock() + cc.SetRTTStatsProvider(h.rttStats) + h.congestion = &ccAdapter{cc} + h.congestionMutex.Unlock() +} diff --git a/internal/congestion/interface.go b/internal/congestion/interface.go index 881f453b69a..54fcbbea35b 100644 --- a/internal/congestion/interface.go +++ b/internal/congestion/interface.go @@ -3,6 +3,7 @@ package congestion import ( "time" + "github.com/quic-go/quic-go/congestion" "github.com/quic-go/quic-go/internal/protocol" ) @@ -19,6 +20,11 @@ type SendAlgorithm interface { SetMaxDatagramSize(protocol.ByteCount) } +type SendAlgorithmEx interface { + SendAlgorithm + OnCongestionEventEx(priorInFlight protocol.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) +} + // A SendAlgorithmWithDebugInfos is a SendAlgorithm that exposes some debug infos type SendAlgorithmWithDebugInfos interface { SendAlgorithm diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 9efeb9be7a3..b5796c4102a 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -13,6 +13,7 @@ import ( reflect "reflect" time "time" + congestion "github.com/quic-go/quic-go/congestion" ackhandler "github.com/quic-go/quic-go/internal/ackhandler" protocol "github.com/quic-go/quic-go/internal/protocol" wire "github.com/quic-go/quic-go/internal/wire" @@ -494,6 +495,42 @@ func (c *MockSentPacketHandlerSentPacketCall) DoAndReturn(f 
func(time.Time, prot return c } +// SetCongestionControl mocks base method. +func (m *MockSentPacketHandler) SetCongestionControl(arg0 congestion.CongestionControl) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetCongestionControl", arg0) +} + +// SetCongestionControl indicates an expected call of SetCongestionControl. +func (mr *MockSentPacketHandlerMockRecorder) SetCongestionControl(arg0 any) *MockSentPacketHandlerSetCongestionControlCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCongestionControl", reflect.TypeOf((*MockSentPacketHandler)(nil).SetCongestionControl), arg0) + return &MockSentPacketHandlerSetCongestionControlCall{Call: call} +} + +// MockSentPacketHandlerSetCongestionControlCall wrap *gomock.Call +type MockSentPacketHandlerSetCongestionControlCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerSetCongestionControlCall) Return() *MockSentPacketHandlerSetCongestionControlCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerSetCongestionControlCall) Do(f func(congestion.CongestionControl)) *MockSentPacketHandlerSetCongestionControlCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerSetCongestionControlCall) DoAndReturn(f func(congestion.CongestionControl)) *MockSentPacketHandlerSetCongestionControlCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // SetHandshakeConfirmed mocks base method. func (m *MockSentPacketHandler) SetHandshakeConfirmed() { m.ctrl.T.Helper() diff --git a/internal/mocks/quic/early_conn.go b/internal/mocks/quic/early_conn.go index a6b173419aa..cbcd3f09b9a 100644 --- a/internal/mocks/quic/early_conn.go +++ b/internal/mocks/quic/early_conn.go @@ -15,6 +15,7 @@ import ( reflect "reflect" quic "github.com/quic-go/quic-go" + congestion "github.com/quic-go/quic-go/congestion" qerr "github.com/quic-go/quic-go/internal/qerr" gomock "go.uber.org/mock/gomock" ) @@ -619,3 +620,39 @@ func (c *MockEarlyConnectionSendDatagramCall) DoAndReturn(f func([]byte) error) c.Call = c.Call.DoAndReturn(f) return c } + +// SetCongestionControl mocks base method. +func (m *MockEarlyConnection) SetCongestionControl(arg0 congestion.CongestionControl) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetCongestionControl", arg0) +} + +// SetCongestionControl indicates an expected call of SetCongestionControl. 
+func (mr *MockEarlyConnectionMockRecorder) SetCongestionControl(arg0 any) *MockEarlyConnectionSetCongestionControlCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCongestionControl", reflect.TypeOf((*MockEarlyConnection)(nil).SetCongestionControl), arg0) + return &MockEarlyConnectionSetCongestionControlCall{Call: call} +} + +// MockEarlyConnectionSetCongestionControlCall wrap *gomock.Call +type MockEarlyConnectionSetCongestionControlCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionSetCongestionControlCall) Return() *MockEarlyConnectionSetCongestionControlCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionSetCongestionControlCall) Do(f func(congestion.CongestionControl)) *MockEarlyConnectionSetCongestionControlCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionSetCongestionControlCall) DoAndReturn(f func(congestion.CongestionControl)) *MockEarlyConnectionSetCongestionControlCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/internal/protocol/params.go b/internal/protocol/params.go index 7c4d8d4de8b..33043038967 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -3,16 +3,16 @@ package protocol import "time" // DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use. -const DesiredReceiveBufferSize = (1 << 20) * 7 // 7 MB +const DesiredReceiveBufferSize = (1 << 20) * 8 // 8 MB // DesiredSendBufferSize is the kernel UDP send buffer size that we'd like to use. -const DesiredSendBufferSize = (1 << 20) * 7 // 7 MB +const DesiredSendBufferSize = (1 << 20) * 8 // 8 MB // InitialPacketSize is the initial (before Path MTU discovery) maximum packet size used. const InitialPacketSize = 1280 // MaxCongestionWindowPackets is the maximum congestion window in packet. -const MaxCongestionWindowPackets = 10000 +const MaxCongestionWindowPackets = 20000 // MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the connection. const MaxUndecryptablePackets = 32 @@ -22,7 +22,7 @@ const MaxUndecryptablePackets = 32 const ConnectionFlowControlMultiplier = 1.5 // DefaultInitialMaxStreamData is the default initial stream-level flow control window for receiving data -const DefaultInitialMaxStreamData = (1 << 10) * 512 // 512 kb +const DefaultInitialMaxStreamData = (1 << 20) * 2 // 2 MB // DefaultInitialMaxData is the connection-level flow control window for receiving data const DefaultInitialMaxData = ConnectionFlowControlMultiplier * DefaultInitialMaxStreamData @@ -78,7 +78,7 @@ const MaxNonAckElicitingAcks = 19 // MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames // prevents DoS attacks against the streamFrameSorter -const MaxStreamFrameSorterGaps = 1000 +const MaxStreamFrameSorterGaps = 20000 // MinStreamFrameBufferSize is the minimum data length of a received STREAM frame // that we use the buffer for. 
This protects against a DoS where an attacker would send us diff --git a/internal/utils/streamframe_interval.go b/internal/utils/streamframe_interval.go new file mode 100644 index 00000000000..78c411242b1 --- /dev/null +++ b/internal/utils/streamframe_interval.go @@ -0,0 +1,45 @@ +package utils + +import ( + "fmt" + + "github.com/quic-go/quic-go/internal/protocol" +) + +// ByteInterval is an interval from one ByteCount to the other +type ByteInterval struct { + Start protocol.ByteCount + End protocol.ByteCount +} + +func (i ByteInterval) Comp(v ByteInterval) int8 { + if i.Start < v.Start { + return -1 + } + if i.Start > v.Start { + return 1 + } + if i.End < v.End { + return -1 + } + if i.End > v.End { + return 1 + } + return 0 +} + +func (i ByteInterval) Match(n ByteInterval) int8 { + // check if there is an overlap + if i.Start <= n.End && i.End >= n.Start { + return 0 + } + if i.Start > n.End { + return 1 + } else { + return -1 + } +} + +func (i ByteInterval) String() string { + return fmt.Sprintf("[%d, %d]", i.Start, i.End) +} diff --git a/internal/utils/tree/tree.go b/internal/utils/tree/tree.go new file mode 100644 index 00000000000..e0b82899035 --- /dev/null +++ b/internal/utils/tree/tree.go @@ -0,0 +1,507 @@ +// Originated from https://github.com/ross-oreto/go-tree/blob/master/btree.go with the following changes: +// 1. Genericized the code +// 2. Added Match function for our frame sorter use case +// 3. Fixed a bug in deleteNode where in some cases the deleted flag was not set to true + +/* +Copyright (c) 2017 Ross Oreto + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +package tree + +import ( + "fmt" +) + +type Val[T any] interface { + Comp(val T) int8 // returns 1 if > val, -1 if < val, 0 if equals to val + Match(cond T) int8 // returns 1 if > cond, -1 if < cond, 0 if matches cond +} + +// Btree represents an AVL tree +type Btree[T Val[T]] struct { + root *Node[T] + values []T + len int +} + +// Node represents a node in the tree with a value, left and right children, and a height/balance of the node. 
+type Node[T Val[T]] struct { + Value T + left, right *Node[T] + height int8 +} + +// New returns a new btree +func New[T Val[T]]() *Btree[T] { return new(Btree[T]).Init() } + +// Init initializes all values/clears the tree and returns the tree pointer +func (t *Btree[T]) Init() *Btree[T] { + t.root = nil + t.values = nil + t.len = 0 + return t +} + +// String returns a string representation of the tree values +func (t *Btree[T]) String() string { + return fmt.Sprint(t.Values()) +} + +// Empty returns true if the tree is empty +func (t *Btree[T]) Empty() bool { + return t.root == nil +} + +// NotEmpty returns true if the tree is not empty +func (t *Btree[T]) NotEmpty() bool { + return t.root != nil +} + +// Insert inserts a new value into the tree and returns the tree pointer +func (t *Btree[T]) Insert(value T) *Btree[T] { + added := false + t.root = insert(t.root, value, &added) + if added { + t.len++ + } + t.values = nil + return t +} + +func insert[T Val[T]](n *Node[T], value T, added *bool) *Node[T] { + if n == nil { + *added = true + return (&Node[T]{Value: value}).Init() + } + c := value.Comp(n.Value) + if c > 0 { + n.right = insert(n.right, value, added) + } else if c < 0 { + n.left = insert(n.left, value, added) + } else { + n.Value = value + *added = false + return n + } + + n.height = n.maxHeight() + 1 + c = balance(n) + + if c > 1 { + c = value.Comp(n.left.Value) + if c < 0 { + return n.rotateRight() + } else if c > 0 { + n.left = n.left.rotateLeft() + return n.rotateRight() + } + } else if c < -1 { + c = value.Comp(n.right.Value) + if c > 0 { + return n.rotateLeft() + } else if c < 0 { + n.right = n.right.rotateRight() + return n.rotateLeft() + } + } + return n +} + +// InsertAll inserts all the values into the tree and returns the tree pointer +func (t *Btree[T]) InsertAll(values []T) *Btree[T] { + for _, v := range values { + t.Insert(v) + } + return t +} + +// Contains returns true if the tree contains the specified value +func (t *Btree[T]) Contains(value T) bool { + return t.Get(value) != nil +} + +// ContainsAny returns true if the tree contains any of the values +func (t *Btree[T]) ContainsAny(values []T) bool { + for _, v := range values { + if t.Contains(v) { + return true + } + } + return false +} + +// ContainsAll returns true if the tree contains all of the values +func (t *Btree[T]) ContainsAll(values []T) bool { + for _, v := range values { + if !t.Contains(v) { + return false + } + } + return true +} + +// Get returns the node value associated with the search value +func (t *Btree[T]) Get(value T) *T { + var node *Node[T] + if t.root != nil { + node = t.root.get(value) + } + if node != nil { + return &node.Value + } + return nil +} + +func (t *Btree[T]) Match(cond T) []T { + var matches []T + if t.root != nil { + t.root.match(cond, &matches) + } + return matches +} + +// Len return the number of nodes in the tree +func (t *Btree[T]) Len() int { + return t.len +} + +// Head returns the first value in the tree +func (t *Btree[T]) Head() *T { + if t.root == nil { + return nil + } + beginning := t.root + for beginning.left != nil { + beginning = beginning.left + } + if beginning == nil { + for beginning.right != nil { + beginning = beginning.right + } + } + if beginning != nil { + return &beginning.Value + } + return nil +} + +// Tail returns the last value in the tree +func (t *Btree[T]) Tail() *T { + if t.root == nil { + return nil + } + beginning := t.root + for beginning.right != nil { + beginning = beginning.right + } + if beginning == nil { + for beginning.left 
!= nil { + beginning = beginning.left + } + } + if beginning != nil { + return &beginning.Value + } + return nil +} + +// Values returns a slice of all the values in tree in order +func (t *Btree[T]) Values() []T { + if t.values == nil { + t.values = make([]T, t.len) + t.Ascend(func(n *Node[T], i int) bool { + t.values[i] = n.Value + return true + }) + } + return t.values +} + +// Delete deletes the node from the tree associated with the search value +func (t *Btree[T]) Delete(value T) *Btree[T] { + deleted := false + t.root = deleteNode(t.root, value, &deleted) + if deleted { + t.len-- + } + t.values = nil + return t +} + +// DeleteAll deletes the nodes from the tree associated with the search values +func (t *Btree[T]) DeleteAll(values []T) *Btree[T] { + for _, v := range values { + t.Delete(v) + } + return t +} + +func deleteNode[T Val[T]](n *Node[T], value T, deleted *bool) *Node[T] { + if n == nil { + return n + } + + c := value.Comp(n.Value) + + if c < 0 { + n.left = deleteNode(n.left, value, deleted) + } else if c > 0 { + n.right = deleteNode(n.right, value, deleted) + } else { + if n.left == nil { + t := n.right + n.Init() + *deleted = true + return t + } else if n.right == nil { + t := n.left + n.Init() + *deleted = true + return t + } + t := n.right.min() + n.Value = t.Value + n.right = deleteNode(n.right, t.Value, deleted) + *deleted = true + } + + // re-balance + if n == nil { + return n + } + n.height = n.maxHeight() + 1 + bal := balance(n) + if bal > 1 { + if balance(n.left) >= 0 { + return n.rotateRight() + } + n.left = n.left.rotateLeft() + return n.rotateRight() + } else if bal < -1 { + if balance(n.right) <= 0 { + return n.rotateLeft() + } + n.right = n.right.rotateRight() + return n.rotateLeft() + } + + return n +} + +// Pop deletes the last node from the tree and returns its value +func (t *Btree[T]) Pop() *T { + value := t.Tail() + if value != nil { + t.Delete(*value) + } + return value +} + +// Pull deletes the first node from the tree and returns its value +func (t *Btree[T]) Pull() *T { + value := t.Head() + if value != nil { + t.Delete(*value) + } + return value +} + +// NodeIterator expresses the iterator function used for traversals +type NodeIterator[T Val[T]] func(n *Node[T], i int) bool + +// Ascend performs an ascending order traversal of the tree calling the iterator function on each node +// the iterator will continue as long as the NodeIterator returns true +func (t *Btree[T]) Ascend(iterator NodeIterator[T]) { + var i int + if t.root != nil { + t.root.iterate(iterator, &i, true) + } +} + +// Descend performs a descending order traversal of the tree using the iterator +// the iterator will continue as long as the NodeIterator returns true +func (t *Btree[T]) Descend(iterator NodeIterator[T]) { + var i int + if t.root != nil { + t.root.rIterate(iterator, &i, true) + } +} + +// Debug prints out useful debug information about the tree for debugging purposes +func (t *Btree[T]) Debug() { + fmt.Println("----------------------------------------------------------------------------------------------") + if t.Empty() { + fmt.Println("tree is empty") + } else { + fmt.Println(t.Len(), "elements") + } + + t.Ascend(func(n *Node[T], i int) bool { + if t.root.Value.Comp(n.Value) == 0 { + fmt.Print("ROOT ** ") + } + n.Debug() + return true + }) + fmt.Println("----------------------------------------------------------------------------------------------") +} + +// Init initializes the values of the node or clears the node and returns the node pointer +func (n *Node[T]) 
Init() *Node[T] { + n.height = 1 + n.left = nil + n.right = nil + return n +} + +// String returns a string representing the node +func (n *Node[T]) String() string { + return fmt.Sprint(n.Value) +} + +// Debug prints out useful debug information about the tree node for debugging purposes +func (n *Node[T]) Debug() { + var children string + if n.left == nil && n.right == nil { + children = "no children |" + } else if n.left != nil && n.right != nil { + children = fmt.Sprint("left child:", n.left.String(), " right child:", n.right.String()) + } else if n.right != nil { + children = fmt.Sprint("right child:", n.right.String()) + } else { + children = fmt.Sprint("left child:", n.left.String()) + } + + fmt.Println(n.String(), "|", "height", n.height, "|", "balance", balance(n), "|", children) +} + +func height[T Val[T]](n *Node[T]) int8 { + if n != nil { + return n.height + } + return 0 +} + +func balance[T Val[T]](n *Node[T]) int8 { + if n == nil { + return 0 + } + return height(n.left) - height(n.right) +} + +func (n *Node[T]) get(val T) *Node[T] { + var node *Node[T] + c := val.Comp(n.Value) + if c < 0 { + if n.left != nil { + node = n.left.get(val) + } + } else if c > 0 { + if n.right != nil { + node = n.right.get(val) + } + } else { + node = n + } + return node +} + +func (n *Node[T]) match(cond T, results *[]T) { + c := n.Value.Match(cond) + if c > 0 { + if n.left != nil { + n.left.match(cond, results) + } + } else if c < 0 { + if n.right != nil { + n.right.match(cond, results) + } + } else { + // other matching nodes could be on both sides + if n.left != nil { + n.left.match(cond, results) + } + *results = append(*results, n.Value) + if n.right != nil { + n.right.match(cond, results) + } + } +} + +func (n *Node[T]) rotateRight() *Node[T] { + l := n.left + // Rotation + l.right, n.left = n, l.right + + // update heights + n.height = n.maxHeight() + 1 + l.height = l.maxHeight() + 1 + + return l +} + +func (n *Node[T]) rotateLeft() *Node[T] { + r := n.right + // Rotation + r.left, n.right = n, r.left + + // update heights + n.height = n.maxHeight() + 1 + r.height = r.maxHeight() + 1 + + return r +} + +func (n *Node[T]) iterate(iterator NodeIterator[T], i *int, cont bool) { + if n != nil && cont { + n.left.iterate(iterator, i, cont) + cont = iterator(n, *i) + *i++ + n.right.iterate(iterator, i, cont) + } +} + +func (n *Node[T]) rIterate(iterator NodeIterator[T], i *int, cont bool) { + if n != nil && cont { + n.right.iterate(iterator, i, cont) + cont = iterator(n, *i) + *i++ + n.left.iterate(iterator, i, cont) + } +} + +func (n *Node[T]) min() *Node[T] { + current := n + for current.left != nil { + current = current.left + } + return current +} + +func (n *Node[T]) maxHeight() int8 { + rh := height(n.right) + lh := height(n.left) + if rh > lh { + return rh + } + return lh +} diff --git a/internal/utils/tree/tree_match_test.go b/internal/utils/tree/tree_match_test.go new file mode 100644 index 00000000000..bb6c5b16229 --- /dev/null +++ b/internal/utils/tree/tree_match_test.go @@ -0,0 +1,95 @@ +package tree + +import ( + "testing" +) + +type interval struct { + start, end int +} + +func (i interval) Comp(ot interval) int8 { + if i.start < ot.start { + return -1 + } + if i.start > ot.start { + return 1 + } + if i.end < ot.end { + return -1 + } + if i.end > ot.end { + return 1 + } + return 0 +} + +func (i interval) Match(ot interval) int8 { + // Check for overlap + if i.start <= ot.end && i.end >= ot.start { + return 0 + } + if i.start > ot.end { + return 1 + } else { + return -1 + } +} + +func 
TestBtree(t *testing.T) { + values := []interval{ + {start: 9, end: 10}, + {start: 3, end: 4}, + {start: 1, end: 2}, + {start: 5, end: 6}, + {start: 7, end: 8}, + {start: 20, end: 100}, + {start: 11, end: 12}, + } + btree := New[interval]() + btree.InsertAll(values) + + expect, actual := len(values), btree.Len() + if actual != expect { + t.Error("length should equal", expect, "actual", actual) + } + + rs := btree.Match(interval{start: 1, end: 6}) + if len(rs) != 3 { + t.Errorf("expected 3 results, got %d", len(rs)) + } + if rs[0].start != 1 || rs[0].end != 2 { + t.Errorf("expected result 1 to be [1, 2], got %v", rs[0]) + } + if rs[1].start != 3 || rs[1].end != 4 { + t.Errorf("expected result 2 to be [3, 4], got %v", rs[1]) + } + if rs[2].start != 5 || rs[2].end != 6 { + t.Errorf("expected result 3 to be [5, 6], got %v", rs[2]) + } + + btree.Delete(interval{start: 5, end: 6}) + + rs = btree.Match(interval{start: 1, end: 6}) + if len(rs) != 2 { + t.Errorf("expected 2 results, got %d", len(rs)) + } + if rs[0].start != 1 || rs[0].end != 2 { + t.Errorf("expected result 1 to be [1, 2], got %v", rs[0]) + } + if rs[1].start != 3 || rs[1].end != 4 { + t.Errorf("expected result 2 to be [3, 4], got %v", rs[1]) + } + + btree.Delete(interval{start: 11, end: 12}) + + rs = btree.Match(interval{start: 12, end: 19}) + if len(rs) != 0 { + t.Errorf("expected 0 results, got %d", len(rs)) + } + + expect, actual = len(values)-2, btree.Len() + if actual != expect { + t.Error("length should equal", expect, "actual", actual) + } +} diff --git a/internal/utils/tree/tree_test.go b/internal/utils/tree/tree_test.go new file mode 100644 index 00000000000..547bf4d450c --- /dev/null +++ b/internal/utils/tree/tree_test.go @@ -0,0 +1,254 @@ +package tree + +import ( + "reflect" + "testing" +) + +type IntVal int + +func (i IntVal) Comp(v IntVal) int8 { + if i > v { + return 1 + } else if i < v { + return -1 + } else { + return 0 + } +} + +func (i IntVal) Match(v IntVal) int8 { + // Unused + return 0 +} + +type StringVal string + +func (i StringVal) Comp(v StringVal) int8 { + if i > v { + return 1 + } else if i < v { + return -1 + } else { + return 0 + } +} + +func (i StringVal) Match(v StringVal) int8 { + // Unused + return 0 +} + +func btreeInOrder(n int) *Btree[IntVal] { + btree := New[IntVal]() + for i := 1; i <= n; i++ { + btree.Insert(IntVal(i)) + } + return btree +} + +func btreeFixed[T Val[T]](values []T) *Btree[T] { + btree := New[T]() + btree.InsertAll(values) + return btree +} + +func TestBtree_Get(t *testing.T) { + values := []IntVal{9, 4, 2, 6, 8, 0, 3, 1, 7, 5} + btree := btreeFixed[IntVal](values).InsertAll(values) + + expect, actual := len(values), btree.Len() + if actual != expect { + t.Error("length should equal", expect, "actual", actual) + } + + expect2 := IntVal(2) + if btree.Get(expect2) == nil || *btree.Get(expect2) != expect2 { + t.Error("value should equal", expect2) + } +} + +func TestBtreeString_Get(t *testing.T) { + tree := New[StringVal]() + tree.Insert("Oreto").Insert("Michael").Insert("Ross") + + expect := StringVal("Ross") + if tree.Get(expect) == nil || *tree.Get(expect) != expect { + t.Error("value should equal", expect) + } +} + +func TestBtree_Contains(t *testing.T) { + btree := btreeInOrder(1000) + + test := IntVal(1) + if !btree.Contains(test) { + t.Error("tree should contain", test) + } + + test2 := []IntVal{1, 2, 3, 4} + if !btree.ContainsAll(test2) { + t.Error("tree should contain", test2) + } + + test2 = []IntVal{5} + if !btree.ContainsAny(test2) { + t.Error("tree should contain", 
test2) + } + + test2 = []IntVal{5000, 2000} + if btree.ContainsAny(test2) { + t.Error("tree should not contain any", test2) + } +} + +func TestBtree_String(t *testing.T) { + btree := btreeFixed[IntVal]([]IntVal{1, 2, 3, 4, 5, 6}) + s1 := btree.String() + s2 := "[1 2 3 4 5 6]" + if s1 != s2 { + t.Error(s1, "tree string representation should equal", s2) + } +} + +func TestBtree_Values(t *testing.T) { + const capacity = 3 + btree := btreeFixed[IntVal]([]IntVal{1, 2}) + + b := btree.Values() + c := []IntVal{1, 2} + if !reflect.DeepEqual(c, b) { + t.Error(c, "should equal", b) + } + btree.Insert(IntVal(3)) + + desc := [capacity]IntVal{} + btree.Descend(func(n *Node[IntVal], i int) bool { + desc[i] = n.Value + return true + }) + d := [capacity]IntVal{3, 2, 1} + if !reflect.DeepEqual(desc, d) { + t.Error(desc, "should equal", d) + } + + e := []IntVal{1, 2, 3} + for i, v := range btree.Values() { + if e[i] != v { + t.Error(e[i], "should equal", v) + } + } +} + +func TestBtree_Delete(t *testing.T) { + test := []IntVal{1, 2, 3} + btree := btreeFixed(test) + + btree.DeleteAll(test) + + if !btree.Empty() { + t.Error("tree should be empty") + } + + btree = btreeFixed(test) + pop := btree.Pop() + if pop == nil || *pop != IntVal(3) { + t.Error(pop, "should be 3") + } + pull := btree.Pull() + if pull == nil || *pull != IntVal(1) { + t.Error(pop, "should be 3") + } + if !btree.Delete(*btree.Pop()).Empty() { + t.Error("tree should be empty") + } + btree.Pop() + btree.Pull() +} + +func TestBtree_HeadTail(t *testing.T) { + btree := btreeFixed[IntVal]([]IntVal{1, 2, 3}) + if btree.Head() == nil || *btree.Head() != IntVal(1) { + t.Error("head element should be 1") + } + if btree.Tail() == nil || *btree.Tail() != IntVal(3) { + t.Error("head element should be 3") + } + btree.Init() + if btree.Head() != nil { + t.Error("head element should be nil") + } +} + +type TestKey1 struct { + Name string +} + +func (testkey TestKey1) Comp(tk TestKey1) int8 { + var c int8 + if testkey.Name > tk.Name { + c = 1 + } else if testkey.Name < tk.Name { + c = -1 + } + return c +} + +func (testkey TestKey1) Match(tk TestKey1) int8 { + // Unused + return 0 +} + +func TestBtree_CustomKey(t *testing.T) { + btree := New[TestKey1]() + btree.InsertAll([]TestKey1{ + {Name: "Ross"}, + {Name: "Michael"}, + {Name: "Angelo"}, + {Name: "Jason"}, + }) + + rootName := btree.root.Value.Name + if btree.root.Value.Name != "Michael" { + t.Error(rootName, "should equal Michael") + } + btree.Init() + btree.InsertAll([]TestKey1{ + {Name: "Ross"}, + {Name: "Michael"}, + {Name: "Angelo"}, + {Name: "Jason"}, + }) + btree.Debug() + s := btree.String() + test := "[{Angelo} {Jason} {Michael} {Ross}]" + if s != test { + t.Error(s, "should equal", test) + } + + btree.Delete(TestKey1{Name: "Michael"}) + if btree.Len() != 3 { + t.Error("tree length should be 3") + } + test = "Jason" + if btree.root.Value.Name != test { + t.Error(btree.root.Value, "root of the tree should be", test) + } + for !btree.Empty() { + btree.Delete(btree.root.Value) + } + btree.Debug() +} + +func TestBtree_Duplicates(t *testing.T) { + btree := New[IntVal]() + btree.InsertAll([]IntVal{ + 0, 2, 5, 10, 15, 20, 12, 14, + 13, 25, 0, 2, 5, 10, 15, 20, 12, 14, 13, 25, + }) + test := 10 + length := btree.Len() + if length != test { + t.Error(length, "tree length should be", test) + } +} diff --git a/internal/wire/datagram_frame.go b/internal/wire/datagram_frame.go index 071fda9a05b..a8995053e94 100644 --- a/internal/wire/datagram_frame.go +++ b/internal/wire/datagram_frame.go @@ -11,7 +11,7 @@ 
import ( // By setting it to a large value, we allow all datagrams that fit into a QUIC packet. // The value is chosen such that it can still be encoded as a 2 byte varint. // This is a var and not a const so it can be set in tests. -var MaxDatagramSize protocol.ByteCount = 16383 +var MaxDatagramSize protocol.ByteCount = 1200 // A DatagramFrame is a DATAGRAM frame type DatagramFrame struct { diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index 785ef5d3a2a..ad01990d0a0 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -14,6 +14,7 @@ import ( net "net" reflect "reflect" + congestion "github.com/quic-go/quic-go/congestion" qerr "github.com/quic-go/quic-go/internal/qerr" gomock "go.uber.org/mock/gomock" ) @@ -619,6 +620,42 @@ func (c *MockQUICConnSendDatagramCall) DoAndReturn(f func([]byte) error) *MockQU return c } +// SetCongestionControl mocks base method. +func (m *MockQUICConn) SetCongestionControl(arg0 congestion.CongestionControl) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetCongestionControl", arg0) +} + +// SetCongestionControl indicates an expected call of SetCongestionControl. +func (mr *MockQUICConnMockRecorder) SetCongestionControl(arg0 any) *MockQUICConnSetCongestionControlCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCongestionControl", reflect.TypeOf((*MockQUICConn)(nil).SetCongestionControl), arg0) + return &MockQUICConnSetCongestionControlCall{Call: call} +} + +// MockQUICConnSetCongestionControlCall wrap *gomock.Call +type MockQUICConnSetCongestionControlCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnSetCongestionControlCall) Return() *MockQUICConnSetCongestionControlCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnSetCongestionControlCall) Do(f func(congestion.CongestionControl)) *MockQUICConnSetCongestionControlCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnSetCongestionControlCall) DoAndReturn(f func(congestion.CongestionControl)) *MockQUICConnSetCongestionControlCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // closeWithTransportError mocks base method. func (m *MockQUICConn) closeWithTransportError(arg0 qerr.TransportErrorCode) { m.ctrl.T.Helper() diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index 302bafa245e..93df057e75f 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -154,6 +154,42 @@ func (c *MockSendConnRemoteAddrCall) DoAndReturn(f func() net.Addr) *MockSendCon return c } +// SetRemoteAddr mocks base method. +func (m *MockSendConn) SetRemoteAddr(arg0 net.Addr) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetRemoteAddr", arg0) +} + +// SetRemoteAddr indicates an expected call of SetRemoteAddr. 
+func (mr *MockSendConnMockRecorder) SetRemoteAddr(arg0 any) *MockSendConnSetRemoteAddrCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRemoteAddr", reflect.TypeOf((*MockSendConn)(nil).SetRemoteAddr), arg0) + return &MockSendConnSetRemoteAddrCall{Call: call} +} + +// MockSendConnSetRemoteAddrCall wrap *gomock.Call +type MockSendConnSetRemoteAddrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendConnSetRemoteAddrCall) Return() *MockSendConnSetRemoteAddrCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendConnSetRemoteAddrCall) Do(f func(net.Addr)) *MockSendConnSetRemoteAddrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendConnSetRemoteAddrCall) DoAndReturn(f func(net.Addr)) *MockSendConnSetRemoteAddrCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Write mocks base method. func (m *MockSendConn) Write(arg0 []byte, arg1 uint16, arg2 protocol.ECN) error { m.ctrl.T.Helper() diff --git a/module_rename.py b/module_rename.py new file mode 100644 index 00000000000..0fc97c8ffa4 --- /dev/null +++ b/module_rename.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import argparse +import fileinput + + +PKG_ORIGINAL = "github.com/quic-go/quic-go" +PKG_NEW = "github.com/apernet/quic-go" + +EXTENSIONS = [".go", ".md", ".mod", ".sh"] + +parser = argparse.ArgumentParser() +parser.add_argument("-r", "--reverse", action="store_true") +args = parser.parse_args() + + +def replace_line(line): + if args.reverse: + return line.replace(PKG_NEW, PKG_ORIGINAL) + return line.replace(PKG_ORIGINAL, PKG_NEW) + + +for dirpath, dirnames, filenames in os.walk("."): + # Skip hidden directories like .git + dirnames[:] = [d for d in dirnames if not d[0] == "."] + filenames = [f for f in filenames if os.path.splitext(f)[1] in EXTENSIONS] + for filename in filenames: + file_path = os.path.join(dirpath, filename) + with fileinput.FileInput(file_path, inplace=True) as file: + for line in file: + print(replace_line(line), end="") diff --git a/send_conn.go b/send_conn.go index 498ed112b46..7f4c3c2c79e 100644 --- a/send_conn.go +++ b/send_conn.go @@ -2,6 +2,7 @@ package quic import ( "net" + "sync/atomic" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" @@ -13,6 +14,7 @@ type sendConn interface { Close() error LocalAddr() net.Addr RemoteAddr() net.Addr + SetRemoteAddr(net.Addr) capabilities() connCapabilities } @@ -21,7 +23,7 @@ type sconn struct { rawConn localAddr net.Addr - remoteAddr net.Addr + remoteAddr atomic.Value logger utils.Logger @@ -49,22 +51,25 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge // increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating l := len(oob) oob = append(oob, make([]byte, 64)...)[:l] - return &sconn{ + sc := &sconn{ rawConn: c, localAddr: localAddr, - remoteAddr: remote, + remoteAddr: atomic.Value{}, packetInfoOOB: oob, logger: logger, } + sc.SetRemoteAddr(remote) + return sc } func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error { - err := c.writePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn) + remoteAddr := c.remoteAddr.Load().(net.Addr) + _, err := c.WritePacket(p, remoteAddr, c.packetInfoOOB, gsoSize, ecn) if err != nil && isGSOError(err) { // disable GSO for future calls c.gotGSOError = true if c.logger.Debug() 
{ - c.logger.Debugf("GSO failed when sending to %s", c.remoteAddr) + c.logger.Debugf("GSO failed when sending to %s", remoteAddr) } // send out the packets one by one for len(p) > 0 { @@ -72,7 +77,7 @@ func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error { if l > int(gsoSize) { l = int(gsoSize) } - if err := c.writePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil { + if _, err := c.WritePacket(p[:l], remoteAddr, c.packetInfoOOB, 0, ecn); err != nil { return err } p = p[l:] @@ -99,5 +104,12 @@ func (c *sconn) capabilities() connCapabilities { return capabilities } -func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *sconn) RemoteAddr() net.Addr { return c.remoteAddr.Load().(net.Addr) } func (c *sconn) LocalAddr() net.Addr { return c.localAddr } + +func (c *sconn) SetRemoteAddr(addr net.Addr) { + if addr == nil { + return + } + c.remoteAddr.Store(addr) +} diff --git a/sys_conn.go b/sys_conn.go index 71cc46070ce..a50c4c7a7c0 100644 --- a/sys_conn.go +++ b/sys_conn.go @@ -1,11 +1,7 @@ package quic import ( - "log" "net" - "os" - "strconv" - "strings" "syscall" "time" @@ -27,26 +23,8 @@ type OOBCapablePacketConn interface { var _ OOBCapablePacketConn = &net.UDPConn{} func wrapConn(pc net.PacketConn) (rawConn, error) { - if err := setReceiveBuffer(pc); err != nil { - if !strings.Contains(err.Error(), "use of closed network connection") { - setBufferWarningOnce.Do(func() { - if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { - return - } - log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err) - }) - } - } - if err := setSendBuffer(pc); err != nil { - if !strings.Contains(err.Error(), "use of closed network connection") { - setBufferWarningOnce.Do(func() { - if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { - return - } - log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err) - }) - } - } + _ = setReceiveBuffer(pc) + _ = setSendBuffer(pc) conn, ok := pc.(interface { SyscallConn() (syscall.RawConn, error) diff --git a/sys_conn_df_linux.go b/sys_conn_df_linux.go index f09eaa5dff8..5ba360b628c 100644 --- a/sys_conn_df_linux.go +++ b/sys_conn_df_linux.go @@ -29,7 +29,7 @@ func setDF(rawConn syscall.RawConn) (bool, error) { case errDFIPv4 != nil && errDFIPv6 == nil: utils.DefaultLogger.Debugf("Setting DF for IPv6.") case errDFIPv4 != nil && errDFIPv6 != nil: - return false, errors.New("setting DF failed for both IPv4 and IPv6") + utils.DefaultLogger.Debugf("Setting DF failed for both IPv4 and IPv6.") } return true, nil } diff --git a/sys_conn_df_windows.go b/sys_conn_df_windows.go index e27635ec9c8..850d620ddd5 100644 --- a/sys_conn_df_windows.go +++ b/sys_conn_df_windows.go @@ -38,7 +38,7 @@ func setDF(rawConn syscall.RawConn) (bool, error) { case errDFIPv4 != nil && errDFIPv6 == nil: utils.DefaultLogger.Debugf("Setting DF for IPv6.") case errDFIPv4 != nil && errDFIPv6 != nil: - return false, errors.New("setting DF failed for both IPv4 and IPv6") + utils.DefaultLogger.Debugf("Setting DF failed for both IPv4 and IPv6.") } return true, nil } diff --git a/sys_conn_helper_linux.go b/sys_conn_helper_linux.go index eec127197cd..eaf82ac8622 100644 --- a/sys_conn_helper_linux.go +++ b/sys_conn_helper_linux.go @@ -102,7 +102,7 @@ func isGSOError(err error) bool { // which is a hard requirement of UDP_SEGMENT. 
See: // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 - return serr.Err == unix.EIO + return serr.Err == unix.EIO || serr.Err == unix.EINVAL } return false } diff --git a/transport.go b/transport.go index 059f30f5beb..ae5b75acda3 100644 --- a/transport.go +++ b/transport.go @@ -355,9 +355,6 @@ func (t *Transport) close(e error) { t.closed = true } -// only print warnings about the UDP receive buffer size once -var setBufferWarningOnce sync.Once - func (t *Transport) listen(conn rawConn) { defer close(t.listening) defer getMultiplexer().RemoveConn(t.Conn)
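Editor's note: the interval lookup exercised by TestBtree above is easiest to see in isolation. The sketch below lives hypothetically inside package tree (internal/utils/tree); the span type and the Comp/Match contract (int8 results, with 0 taken to mean "overlaps the query") are assumptions inferred from the tests in this diff, not code from it.

// Sketch only: a hypothetical element type for the generic B-tree added in
// internal/utils/tree. Comp orders elements by start offset; Match classifies
// an element against a query range (assumed: 0 = overlap, -1 = entirely
// before, 1 = entirely after), which would explain why Match(interval{1, 6})
// in the test above returns the overlapping stored intervals in order.
package tree

type span struct{ start, end uint64 }

func (s span) Comp(o span) int8 {
	switch {
	case s.start > o.start:
		return 1
	case s.start < o.start:
		return -1
	default:
		return 0
	}
}

func (s span) Match(o span) int8 {
	switch {
	case s.end < o.start:
		return -1
	case s.start > o.end:
		return 1
	default:
		return 0
	}
}

// exampleOverlap returns every stored span overlapping [1, 6]; with the data
// below that should be {1, 2} and {5, 6}.
func exampleOverlap() []span {
	bt := New[span]()
	bt.InsertAll([]span{{start: 1, end: 2}, {start: 5, end: 6}, {start: 9, end: 10}})
	return bt.Match(span{start: 1, end: 6})
}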
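Editor's note: the sconn changes above replace the plain remoteAddr field with an atomic.Value, presumably so SetRemoteAddr can retarget the connection while Write reads the address on another goroutine without a data race. Below is a minimal sketch of that pattern; the remoteAddr holder type is illustrative, not the diff's code. One caveat: atomic.Value panics if successive Store calls use different concrete types, so all callers are assumed to pass the same concrete net.Addr implementation, in practice *net.UDPAddr.

// Minimal sketch of the lock-free remote-address swap used in sconn above.
package main

import (
	"fmt"
	"net"
	"sync/atomic"
)

// remoteAddr is an illustrative holder type, not part of this diff.
type remoteAddr struct{ v atomic.Value }

// Set ignores nil, mirroring the nil guard in sconn.SetRemoteAddr.
func (r *remoteAddr) Set(a net.Addr) {
	if a == nil {
		return
	}
	r.v.Store(a)
}

// Get returns the last stored address, or nil before the first Set.
func (r *remoteAddr) Get() net.Addr {
	a, _ := r.v.Load().(net.Addr)
	return a
}

func main() {
	var r remoteAddr
	r.Set(&net.UDPAddr{IP: net.IPv4(192, 0, 2, 1), Port: 4433})
	r.Set(&net.UDPAddr{IP: net.IPv4(192, 0, 2, 2), Port: 4433}) // migration: the latest address wins
	fmt.Println(r.Get())                                        // 192.0.2.2:4433
}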
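Editor's note: the SetRemoteAddr expectation added to MockSendConn above is used like any other gomock recorder call. A hypothetical test sketch follows; the test name, the address, and the assumption that the generated NewMockSendConn constructor exists elsewhere in mock_send_conn_test.go are mine, not part of the diff.

package quic

import (
	"net"
	"testing"

	"go.uber.org/mock/gomock"
)

// TestSetRemoteAddrExpectation is a hypothetical example of exercising the
// SetRemoteAddr expectation added to MockSendConn above.
func TestSetRemoteAddrExpectation(t *testing.T) {
	ctrl := gomock.NewController(t)
	conn := NewMockSendConn(ctrl)

	addr := &net.UDPAddr{IP: net.IPv4(192, 0, 2, 7), Port: 4433}
	conn.EXPECT().SetRemoteAddr(addr)

	// Code under test would call this when a valid packet arrives from a
	// new remote address.
	conn.SetRemoteAddr(addr)
}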