-
Notifications
You must be signed in to change notification settings - Fork 1
/
ctxvml.go
127 lines (110 loc) · 3.79 KB
/
ctxvml.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package ctxvml
import (
"context"
"strings"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// ctxMarker is the unexported type of the context key under which VmlHeaders
// are stored; since it is unexported, other packages cannot build a colliding key.
type ctxMarker struct{}

// ctxMarkerKey is the single context key used to store and look up VmlHeaders.
var ctxMarkerKey = &ctxMarker{}

// VmlHeaders contains vml http headers propagated between services through
// gRPC metadata.
type VmlHeaders struct {
	// Username is carried in the "vml-username" metadata key.
	Username string
	// OcrCache is carried in the "vml-ocr-cache" metadata key.
	OcrCache string
	// AuthMethod is carried in the "vml-auth-method" metadata key.
	AuthMethod string
}

// OcrCacheAllow reports whether OcrCache equals "allow", ignoring case.
func (h VmlHeaders) OcrCacheAllow() bool {
	const allowValue = "allow"
	return strings.EqualFold(allowValue, h.OcrCache)
}
// UnaryServerInterceptor returns a unary server interceptor that copies the
// vml-* keys of the incoming request metadata into the context passed to the
// handler, where they can be read back with Extract.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
	interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		return handler(extractMetadataToContext(ctx), req)
	}
	return interceptor
}
// StreamServerInterceptor returns a stream server interceptor that copies the
// vml-* keys of the stream's incoming metadata into the context exposed by the
// stream handed to the handler. Extraction happens once, before the handler
// runs; messages received later on the stream are not inspected.
func StreamServerInterceptor() grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		wrapped := serverStreamWithContext{ServerStream: stream}
		wrapped.ctx = extractMetadataToContext(stream.Context())
		return handler(srv, wrapped)
	}
}
// serverStreamWithContext wraps a grpc.ServerStream, delegating every method
// to the embedded stream except Context, which returns the replacement ctx.
type serverStreamWithContext struct {
	grpc.ServerStream
	ctx context.Context
}

// Context returns the replacement context instead of the embedded stream's.
func (w serverStreamWithContext) Context() context.Context {
	return w.ctx
}
// extractMetadataToContext reads the vml-username, vml-ocr-cache and
// vml-auth-method keys from the incoming gRPC metadata of ctx and returns a
// derived context carrying them as a VmlHeaders value (see Extract). Only the
// first value of each key is used. A non-empty username is also set as the
// "username" ctxtag.
func extractMetadataToContext(ctx context.Context) context.Context {
	// Without incoming metadata md is nil; indexing a nil map is safe and just
	// yields empty header values, so the ok result is deliberately ignored.
	md, _ := metadata.FromIncomingContext(ctx)

	first := func(key string) string {
		if values := md[key]; len(values) > 0 {
			return values[0]
		}
		return ""
	}

	headers := VmlHeaders{
		Username:   first("vml-username"),
		OcrCache:   first("vml-ocr-cache"),
		AuthMethod: first("vml-auth-method"),
	}

	if headers.Username != "" {
		grpc_ctxtags.Extract(ctx).Set("username", headers.Username)
	}
	return context.WithValue(ctx, ctxMarkerKey, headers)
}
// UnaryClientInterceptor returns a unary client interceptor that appends the
// caller's VmlHeaders (as packed by packCallerMetadata) to the outgoing gRPC
// metadata of every call.
func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{},
		cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		outgoing := packCallerMetadata(ctx)
		// Append every pair in a single call. Calling AppendToOutgoingContext
		// once per key wraps ctx once per header for the same final metadata.
		if len(outgoing) > 0 {
			kv := make([]string, 0, 2*len(outgoing))
			for k, v := range outgoing {
				kv = append(kv, k, v)
			}
			ctx = metadata.AppendToOutgoingContext(ctx, kv...)
		}
		return invoker(ctx, method, req, reply, cc, opts...)
	}
}
// StreamClientInterceptor returns a stream client interceptor that appends the
// caller's VmlHeaders (as packed by packCallerMetadata) to the outgoing gRPC
// metadata when the stream is opened.
func StreamClientInterceptor() grpc.StreamClientInterceptor {
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn,
		method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		outgoing := packCallerMetadata(ctx)
		// Append every pair in a single call. Calling AppendToOutgoingContext
		// once per key wraps ctx once per header for the same final metadata.
		if len(outgoing) > 0 {
			kv := make([]string, 0, 2*len(outgoing))
			for k, v := range outgoing {
				kv = append(kv, k, v)
			}
			ctx = metadata.AppendToOutgoingContext(ctx, kv...)
		}
		return streamer(ctx, desc, cc, method, opts...)
	}
}
// packCallerMetadata converts the VmlHeaders stored in ctx (see Extract) into
// metadata key/value pairs for outgoing gRPC requests.
//
// Header values that are empty are omitted, so calls made without caller
// information carry no blank vml-* headers. For receivers using
// extractMetadataToContext a missing key and an empty value are equivalent:
// both produce an empty field.
func packCallerMetadata(ctx context.Context) map[string]string {
	headers := Extract(ctx)
	md := make(map[string]string, 3)
	set := func(key, value string) {
		if value != "" {
			md[key] = value
		}
	}
	set("vml-username", headers.Username)
	set("vml-ocr-cache", headers.OcrCache)
	set("vml-auth-method", headers.AuthMethod)
	return md
}
// Extract returns the VmlHeaders stored in ctx by WithValue or by the server
// interceptors. It never returns nil: if ctx carries no headers, the result
// points to a zero-valued VmlHeaders. The pointer refers to a fresh copy, so
// modifying it does not affect the value held by ctx.
func Extract(ctx context.Context) *VmlHeaders {
	// A failed type assertion leaves headers at its zero value, which is the
	// desired fallback, so the ok result is not needed.
	headers, _ := ctx.Value(ctxMarkerKey).(VmlHeaders)
	return &headers
}
// WithValue returns a child of ctx that carries headers. Calling Extract on
// the returned context yields a copy of them, and the client interceptors
// propagate them on outgoing calls.
func WithValue(ctx context.Context, headers VmlHeaders) context.Context {
	child := context.WithValue(ctx, ctxMarkerKey, headers)
	return child
}