diff --git a/instrumentation/net/http/httptrace/otelhttptrace/clienttrace.go b/instrumentation/net/http/httptrace/otelhttptrace/clienttrace.go index 67e03f24810..bb1f63a3bde 100644 --- a/instrumentation/net/http/httptrace/otelhttptrace/clienttrace.go +++ b/instrumentation/net/http/httptrace/otelhttptrace/clienttrace.go @@ -186,6 +186,9 @@ func NewClientTrace(ctx context.Context, opts ...ClientTraceOption) *httptrace.C } func (ct *clientTracer) start(hook, spanName string, attrs ...attribute.KeyValue) { + ct.mtx.Lock() + defer ct.mtx.Unlock() + if !ct.useSpans { if ct.root == nil { ct.root = trace.SpanFromContext(ct.Context) @@ -194,9 +197,6 @@ func (ct *clientTracer) start(hook, spanName string, attrs ...attribute.KeyValue return } - ct.mtx.Lock() - defer ct.mtx.Unlock() - if hookCtx, found := ct.activeHooks[hook]; !found { var sp trace.Span ct.activeHooks[hook], sp = ct.tr.Start(ct.getParentContext(hook), spanName, trace.WithAttributes(attrs...), trace.WithSpanKind(trace.SpanKindClient)) @@ -214,6 +214,13 @@ func (ct *clientTracer) start(hook, spanName string, attrs ...attribute.KeyValue } func (ct *clientTracer) end(hook string, err error, attrs ...attribute.KeyValue) { + ct.mtx.Lock() + defer ct.mtx.Unlock() + + if ct.root == nil { + return + } + if !ct.useSpans { if err != nil { attrs = append(attrs, attribute.String(hook+".error", err.Error())) @@ -222,8 +229,6 @@ func (ct *clientTracer) end(hook string, err error, attrs ...attribute.KeyValue) return } - ct.mtx.Lock() - defer ct.mtx.Unlock() if ctx, ok := ct.activeHooks[hook]; ok { span := trace.SpanFromContext(ctx) if err != nil { @@ -321,6 +326,9 @@ func (ct *clientTracer) tlsHandshakeDone(_ tls.ConnectionState, err error) { } func (ct *clientTracer) wroteHeaderField(k string, v []string) { + if ct.root == nil { + return + } if ct.useSpans && ct.span("http.headers") == nil { ct.start("http.headers", "http.headers") } diff --git a/instrumentation/net/http/httptrace/otelhttptrace/clienttrace_test.go b/instrumentation/net/http/httptrace/otelhttptrace/clienttrace_test.go index 966473dd0c3..4c9c438e950 100644 --- a/instrumentation/net/http/httptrace/otelhttptrace/clienttrace_test.go +++ b/instrumentation/net/http/httptrace/otelhttptrace/clienttrace_test.go @@ -8,6 +8,8 @@ import ( "fmt" "net/http" "net/http/httptrace" + "sync" + "testing" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) @@ -32,3 +34,45 @@ func ExampleNewClientTrace() { fmt.Println(resp.Status) } + +type zeroTripper struct{} + +func (zeroTripper) RoundTrip(_ *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: 200}, nil +} + +var _ http.RoundTripper = zeroTripper{} + +// TestNewClientParallelismWithoutSubspans tests running many Gets on a client simultaneously, +// which would trigger a race condition if root were not protected by a mutex. +func TestNewClientParallelismWithoutSubspans(t *testing.T) { + t.Parallel() + + makeClientTrace := func(ctx context.Context) *httptrace.ClientTrace { + return NewClientTrace(ctx, WithoutSubSpans()) + } + + client := http.Client{ + Transport: otelhttp.NewTransport( + zeroTripper{}, + otelhttp.WithClientTrace(makeClientTrace), + ), + } + + var wg sync.WaitGroup + + for i := 1; i < 10000; i++ { + wg.Add(1) + go func() { + resp, err := client.Get("}}}}}") + if err != nil { + t.Error(err) + return + } + resp.Body.Close() + wg.Done() + }() + } + + wg.Wait() +}