Skip to content

Commit

Permalink
Add traceWroteRequest & traceGotConn
Browse files Browse the repository at this point in the history
  • Loading branch information
RPRX committed Jan 1, 2025
1 parent 16b33a0 commit bee89bb
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
13 changes: 12 additions & 1 deletion http3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,20 @@ func (c *ClientConn) sendRequestBody(str Stream, body io.ReadCloser, contentLeng
return err
}

func traceWroteRequest(ctx context.Context, err error) {
trace := httptrace.ContextClientTrace(ctx)
if trace != nil && trace.WroteRequest != nil {
trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})
}
}

func (c *ClientConn) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
if err := str.SendRequestHeader(req); err != nil {
traceWroteRequest(req.Context(), err)
return nil, err
}
if req.Body == nil {
traceWroteRequest(req.Context(), nil)
str.Close()
} else {
// send the request body asynchronously
Expand All @@ -308,7 +317,9 @@ func (c *ClientConn) doRequest(req *http.Request, str *requestStream) (*http.Res
if req.ContentLength > 0 {
contentLength = req.ContentLength
}
if err := c.sendRequestBody(str, req.Body, contentLength); err != nil {
err := c.sendRequestBody(str, req.Body, contentLength)
traceWroteRequest(req.Context(), err)
if err != nil {
if c.logger != nil {
c.logger.Debug("error writing request", "error", err)
}
Expand Down
26 changes: 26 additions & 0 deletions http3/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (
"log/slog"
"net"
"net/http"
"net/http/httptrace"
"strings"
"sync"
"sync/atomic"
"time"

"golang.org/x/net/http/httpguts"

Expand Down Expand Up @@ -158,6 +160,29 @@ func (t *Transport) init() error {
return nil
}

// fakeConn is a wrapper for quic.EarlyConnection
// because the quic connection does not implement net.Conn.
type fakeConn struct {
conn quic.EarlyConnection
}

func (c *fakeConn) Close() error { panic("connection operation prohibited") }
func (c *fakeConn) Read(p []byte) (int, error) { panic("connection operation prohibited") }
func (c *fakeConn) Write(p []byte) (int, error) { panic("connection operation prohibited") }
func (c *fakeConn) SetDeadline(t time.Time) error { panic("connection operation prohibited") }
func (c *fakeConn) SetReadDeadline(t time.Time) error { panic("connection operation prohibited") }
func (c *fakeConn) SetWriteDeadline(t time.Time) error { panic("connection operation prohibited") }
func (c *fakeConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
func (c *fakeConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
func traceGotConn(trace *httptrace.ClientTrace, conn quic.EarlyConnection, reused bool) {
if trace != nil && trace.GotConn != nil {
trace.GotConn(httptrace.GotConnInfo{
Conn: &fakeConn{conn: conn},
Reused: reused,
})
}
}

// RoundTripOpt is like RoundTrip, but takes options.
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
t.initOnce.Do(func() { t.initErr = t.init() })
Expand Down Expand Up @@ -213,6 +238,7 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
t.removeClient(hostname)
return nil, cl.dialErr
}
traceGotConn(httptrace.ContextClientTrace(req.Context()), cl.conn, isReused)
defer cl.useCount.Add(-1)
rsp, err := cl.rt.RoundTrip(req)
if err != nil {
Expand Down

0 comments on commit bee89bb

Please sign in to comment.