Skip to content

Commit

Permalink
Add context propagation support (#783)
Browse files Browse the repository at this point in the history
* Add context propagation support

In the current implementation, weaver doesn't allow the user to
propagate context information. We recommend users to define a struct
that encapsulates the metadata information and add it as an argument to
the method. However, more and more users are asking for an option to
propagate metadata information using the context. This request comes
especially from users that are using gRPC to communicate between their
services, and gRPC provides a way to propagate metadata information
using the context.

This PR enables the users to propagate metadata information as a
map[string]string.

```main.go
// To attach metadata with key "foo" and value "bar" to the context, you can do:
ctx := context.Background()
ctx = metadata.NewContext(ctx, map[string]string{"foo": "bar"})

// To read the metadata value associated with a key "foo" in the context, you can do:
meta, found := metadata.FromContext(ctx)
if found {
  value := meta["foo"]
}
```

[1] https://pkg.go.dev/google.golang.org/grpc/metadata

* Header enc/dec into their own routines

* Last fixes
  • Loading branch information
rgrandl authored Aug 7, 2024
1 parent a89085a commit d613ffe
Show file tree
Hide file tree
Showing 11 changed files with 656 additions and 83 deletions.
6 changes: 6 additions & 0 deletions godeps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ github.com/ServiceWeaver/weaver/internal/net/call
errors
fmt
github.com/ServiceWeaver/weaver/internal/traceio
github.com/ServiceWeaver/weaver/metadata
github.com/ServiceWeaver/weaver/runtime/codegen
github.com/ServiceWeaver/weaver/runtime/logging
github.com/ServiceWeaver/weaver/runtime/retry
Expand Down Expand Up @@ -695,6 +696,9 @@ github.com/ServiceWeaver/weaver/internal/weaver
sync/atomic
syscall
time
github.com/ServiceWeaver/weaver/metadata
context
maps
github.com/ServiceWeaver/weaver/metrics
github.com/ServiceWeaver/weaver/runtime/metrics
github.com/ServiceWeaver/weaver/runtime/protos
Expand Down Expand Up @@ -1063,9 +1067,11 @@ github.com/ServiceWeaver/weaver/weavertest/internal/simple
errors
fmt
github.com/ServiceWeaver/weaver
github.com/ServiceWeaver/weaver/metadata
github.com/ServiceWeaver/weaver/runtime/codegen
go.opentelemetry.io/otel/codes
go.opentelemetry.io/otel/trace
maps
net/http
os
reflect
Expand Down
101 changes: 71 additions & 30 deletions internal/net/call/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,13 @@ import (
"sync/atomic"
"time"

"github.com/ServiceWeaver/weaver/runtime/codegen"
"github.com/ServiceWeaver/weaver/runtime/logging"
"github.com/ServiceWeaver/weaver/runtime/retry"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)

const (
// Size of the header included in each message.
msgHeaderSize = 16 + 8 + traceHeaderLen // handler_key + deadline + trace_context
)

// Connection allows a client to send RPCs.
type Connection interface {
// Call makes an RPC over a Connection.
Expand Down Expand Up @@ -385,25 +381,29 @@ func (rc *reconnectingConnection) Call(ctx context.Context, h MethodKey, arg []b
}

func (rc *reconnectingConnection) callOnce(ctx context.Context, h MethodKey, arg []byte, opts CallOptions) ([]byte, error) {
var hdr [msgHeaderSize]byte
copy(hdr[0:], h[:])
var micros int64
deadline, haveDeadline := ctx.Deadline()
if haveDeadline {
// Send the deadline in the header. We use the relative time instead
// of absolute in case there is significant clock skew. This does mean
// that we will not count transmission delay against the deadline.
micros := time.Until(deadline).Microseconds()
micros = time.Until(deadline).Microseconds()
if micros <= 0 {
// Fail immediately without attempting to send a zero or negative
// deadline to the server which will be misinterpreted.
<-ctx.Done()
return nil, ctx.Err()
}
binary.LittleEndian.PutUint64(hdr[16:], uint64(micros))
}

// Send trace information in the header.
writeTraceContext(ctx, hdr[24:])
// Encode the header.
hdr := encodeHeader(ctx, h, micros)

// Note that we send the header and the payload as follows:
// [header_length][encoded_header][payload]
var hdrLen [hdrLenLen]byte
binary.LittleEndian.PutUint32(hdrLen[:], uint32(len(hdr)))
hdrSlice := append(hdrLen[:], hdr...)

rpc := &call{}
rpc.doneSignal = make(chan struct{})
Expand All @@ -413,7 +413,7 @@ func (rc *reconnectingConnection) callOnce(ctx context.Context, h MethodKey, arg
if err != nil {
return nil, err
}
if err := writeMessage(nc, &conn.wlock, requestMessage, rpc.id, hdr[:], arg, rc.opts.WriteFlattenLimit); err != nil {
if err := writeMessage(nc, &conn.wlock, requestMessage, rpc.id, hdrSlice, arg, rc.opts.WriteFlattenLimit); err != nil {
conn.shutdown("client send request", err)
conn.endCall(rpc)
return nil, fmt.Errorf("%w: %s", CommunicationError, err)
Expand Down Expand Up @@ -942,35 +942,31 @@ func (c *serverConnection) readRequests(ctx context.Context, hmap *HandlerMap, o
// runHandler runs an application specified RPC handler at the server side.
// The result (or error) from the handler is sent back to the client over c.
func (c *serverConnection) runHandler(hmap *HandlerMap, id uint64, msg []byte) {
// Extract request header from front of payload.
if len(msg) < msgHeaderSize {
msgLen := uint32(len(msg))
if msgLen < hdrLenLen {
c.shutdown("server handler", fmt.Errorf("missing request header length"))
return
}

// Get the header length.
hdrLen := binary.LittleEndian.Uint32(msg[:hdrLenLen])
hdrEndOffset := hdrLenLen + hdrLen
if msgLen < hdrEndOffset {
c.shutdown("server handler", fmt.Errorf("missing request header"))
return
}

// Extract handler key.
var hkey MethodKey
copy(hkey[:], msg)
// Extracts header information.
ctx, hkey, micros, sc := decodeHeader(msg[hdrLenLen:hdrEndOffset])

// Extract the method name
// Extracts the method name.
methodName := hmap.names[hkey]
if methodName == "" {
methodName = "handler"
} else {
methodName = logging.ShortenComponent(methodName)
}

// Extract trace context and create a new child span to trace the method
// call on the server.
ctx := context.Background()
span := trace.SpanFromContext(ctx) // noop span
if sc := readTraceContext(msg[24:]); sc.IsValid() {
ctx, span = c.opts.Tracer.Start(trace.ContextWithSpanContext(ctx, sc), methodName, trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
}

// Add deadline information from the header to the context.
micros := binary.LittleEndian.Uint64(msg[16:])
var cancelFunc func()
if micros != 0 {
deadline := time.Now().Add(time.Microsecond * time.Duration(micros))
Expand All @@ -984,8 +980,19 @@ func (c *serverConnection) runHandler(hmap *HandlerMap, id uint64, msg []byte) {
}
}()

// Create a new child span to trace the method call on the server.
span := trace.SpanFromContext(ctx) // noop span
if sc != nil {
if !sc.IsValid() {
c.shutdown("server handler", fmt.Errorf("invalid span context"))
return
}
ctx, span = c.opts.Tracer.Start(trace.ContextWithSpanContext(ctx, *sc), methodName, trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
}

// Call the handler passing it the payload.
payload := msg[msgHeaderSize:]
payload := msg[hdrEndOffset:]
var err error
var result []byte
fn, ok := hmap.handlers[hkey]
Expand Down Expand Up @@ -1051,6 +1058,40 @@ func (c *serverConnection) shutdown(details string, err error) {
}
}

// encodeHeader encodes the header information that is propagated by each message.
func encodeHeader(ctx context.Context, h MethodKey, micros int64) []byte {
enc := codegen.NewEncoder()
copy(enc.Grow(len(h)), h[:])
enc.Int64(micros)

// Send trace information in the header.
writeTraceContext(ctx, enc)

// Send context metadata in the header.
writeContextMetadata(ctx, enc)

return enc.Data()
}

// decodeHeader extracts the encoded header information.
func decodeHeader(hdr []byte) (context.Context, MethodKey, int64, *trace.SpanContext) {
dec := codegen.NewDecoder(hdr)

// Extract handler key.
var hkey MethodKey
copy(hkey[:], dec.Read(len(hkey)))

// Extract deadline information.
micros := dec.Int64()

// Extract trace context information.
sc := readTraceContext(dec)

// Extract metadata context information if any.
ctx := readContextMetadata(context.Background(), dec)
return ctx, hkey, micros, sc
}

func logError(logger *slog.Logger, details string, err error) {
if errors.Is(err, context.Canceled) ||
errors.Is(err, io.EOF) ||
Expand Down
54 changes: 54 additions & 0 deletions internal/net/call/metadata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package call

import (
"context"

"github.com/ServiceWeaver/weaver/metadata"
"github.com/ServiceWeaver/weaver/runtime/codegen"
)

// writeContextMetadata serializes the context metadata (if any) into enc.
func writeContextMetadata(ctx context.Context, enc *codegen.Encoder) {
m, found := metadata.FromContext(ctx)
if !found {
enc.Bool(false)
return
}
enc.Bool(true)
enc.Len(len(m))
for k, v := range m {
enc.String(k)
enc.String(v)
}
}

// readContextMetadata returns the context metadata (if any) stored in dec.
func readContextMetadata(ctx context.Context, dec *codegen.Decoder) context.Context {
hasMeta := dec.Bool()
if !hasMeta {
return ctx
}
n := dec.Len()
res := make(map[string]string, n)
var k, v string
for i := 0; i < n; i++ {
k = dec.String()
v = dec.String()
res[k] = v
}
return metadata.NewContext(ctx, res)
}
19 changes: 15 additions & 4 deletions internal/net/call/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ const (

const currentVersion = initialVersion

const hdrLenLen = uint32(4) // size of the header length included in each message

// # Message formats
//
// All messages have the following format:
Expand All @@ -60,10 +62,19 @@ const currentVersion = initialVersion
// version [4]byte
//
// requestMessage:
// headerKey [16]byte -- fingerprint of method name
// deadline [8]byte -- zero, or deadline in microseconds
// traceContext [25]byte -- zero, or trace context
// remainder -- call argument serialization
// headerLen [4]byte -- length of the encoded header
// header [headerLen]byte -- encoded header information
// payload -- call argument serialization
//
// The header is encoded using Service Weaver's encoding format for a type that
// looks like:
//
// struct header {
// MethodKey [16]byte
// Deadline int64
// TraceContext [25]byte
// MetadataContext map[string]string
// }
//
// responseMessage:
// payload holds call result serialization
Expand Down
38 changes: 23 additions & 15 deletions internal/net/call/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,46 @@ package call
import (
"context"

"github.com/ServiceWeaver/weaver/runtime/codegen"
"go.opentelemetry.io/otel/trace"
)

const traceHeaderLen = 25

// writeTraceContext serializes the trace context (if any) contained in ctx
// into b.
// REQUIRES: len(b) >= traceHeaderLen
func writeTraceContext(ctx context.Context, b []byte) {
// into enc.
func writeTraceContext(ctx context.Context, enc *codegen.Encoder) {
sc := trace.SpanContextFromContext(ctx)
if !sc.IsValid() {
enc.Bool(false)
return
}
enc.Bool(true)

// Send trace information in the header.
// TODO(spetrovic): Confirm that we don't need to bother with TraceState,
// which seems to be used for storing vendor-specific information.
traceID := sc.TraceID()
spanID := sc.SpanID()
copy(b, traceID[:])
copy(b[16:], spanID[:])
b[24] = byte(sc.TraceFlags())
copy(enc.Grow(len(traceID)), traceID[:])
copy(enc.Grow(len(spanID)), spanID[:])
enc.Byte(byte(sc.TraceFlags()))
}

// readTraceContext returns a span context with tracing information stored in b.
// REQUIRES: len(b) >= traceHeaderLen
func readTraceContext(b []byte) trace.SpanContext {
// readTraceContext returns a span context with tracing information stored in dec.
func readTraceContext(dec *codegen.Decoder) *trace.SpanContext {
hasTrace := dec.Bool()
if !hasTrace {
return nil
}
var traceID trace.TraceID
var spanID trace.SpanID
traceID = *(*trace.TraceID)(dec.Read(len(traceID)))
spanID = *(*trace.SpanID)(dec.Read(len(spanID)))
cfg := trace.SpanContextConfig{
TraceID: *(*trace.TraceID)(b[:16]),
SpanID: *(*trace.SpanID)(b[16:24]),
TraceFlags: trace.TraceFlags(b[24]),
TraceID: traceID,
SpanID: spanID,
TraceFlags: trace.TraceFlags(dec.Byte()),
Remote: true,
}
return trace.NewSpanContext(cfg)
trace := trace.NewSpanContext(cfg)
return &trace
}
11 changes: 6 additions & 5 deletions internal/net/call/trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"testing"

"github.com/ServiceWeaver/weaver/runtime/codegen"
"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
)
Expand All @@ -36,14 +37,14 @@ func TestTraceSerialization(t *testing.T) {
})

// Serialize the trace context.
var b [25]byte
writeTraceContext(
trace.ContextWithSpanContext(context.Background(), span), b[:])
enc := codegen.NewEncoder()
writeTraceContext(trace.ContextWithSpanContext(context.Background(), span), enc)

// Deserialize the trace context.
actual := readTraceContext(b[:])
dec := codegen.NewDecoder(enc.Data())
actual := readTraceContext(dec)
expect := span.WithRemote(true)
if !expect.Equal(actual) {
if !expect.Equal(*actual) {
want, _ := json.Marshal(expect)
got, _ := json.Marshal(actual)
t.Errorf("span context diff, want %q, got %q", want, got)
Expand Down
Loading

0 comments on commit d613ffe

Please sign in to comment.