Skip to content

Commit 788417d

Browse files
committed
[FLINK-26375][statefun-golang-sdk] Fix Statefun Golang SDK to return nil from Context.Caller when there is no caller
Change handler.go to only populate statefunContext.caller when a caller is present in the invocation Add unit tests
1 parent 725202c commit 788417d

File tree

2 files changed

+91
-3
lines changed

2 files changed

+91
-3
lines changed

statefun-sdk-go/v3/pkg/statefun/handler.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,10 @@ func (h *handler) invoke(ctx context.Context, toFunction *protocol.ToFunction) (
225225
var cancel context.CancelFunc
226226
sContext.Context, cancel = context.WithCancel(ctx)
227227

228-
var caller Address
229228
if invocation.Caller != nil {
230-
caller = addressFromInternal(invocation.Caller)
229+
caller := addressFromInternal(invocation.Caller)
230+
sContext.caller = &caller
231231
}
232-
sContext.caller = &caller
233232
msg := Message{
234233
target: batch.Target,
235234
typedValue: invocation.Argument,
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package statefun
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/apache/flink-statefun/statefun-sdk-go/v3/pkg/statefun/internal/protocol"
8+
"github.com/stretchr/testify/assert"
9+
"google.golang.org/protobuf/proto"
10+
)
11+
12+
// helper to create a protocol Address from an Address
13+
func toProtocolAddress(address *Address) *protocol.Address {
14+
if address != nil {
15+
return &protocol.Address{
16+
Namespace: address.FunctionType.GetNamespace(),
17+
Type: address.FunctionType.GetType(),
18+
Id: address.Id,
19+
}
20+
} else {
21+
return nil
22+
}
23+
}
24+
25+
// helper to create a handler and invoke the function
26+
func invokeStatefulFunction(ctx context.Context, target *Address, caller *Address, argument *protocol.TypedValue, statefulFunction StatefulFunction) error {
27+
28+
builder := StatefulFunctionsBuilder()
29+
err := builder.WithSpec(StatefulFunctionSpec{
30+
FunctionType: target.FunctionType,
31+
Function: statefulFunction,
32+
})
33+
if err != nil {
34+
return err
35+
}
36+
37+
toFunction := protocol.ToFunction{
38+
Request: &protocol.ToFunction_Invocation_{
39+
Invocation: &protocol.ToFunction_InvocationBatchRequest{
40+
Target: toProtocolAddress(target),
41+
Invocations: []*protocol.ToFunction_Invocation{
42+
{
43+
Caller: toProtocolAddress(caller),
44+
Argument: argument,
45+
},
46+
},
47+
},
48+
},
49+
}
50+
51+
bytes, err := proto.Marshal(&toFunction)
52+
if err != nil {
53+
return err
54+
}
55+
56+
_, err = builder.AsHandler().Invoke(ctx, bytes)
57+
if err != nil {
58+
return err
59+
}
60+
61+
return nil
62+
}
63+
64+
func TestStatefunHandler_WithNoCaller_ContextCallerIsNil(t *testing.T) {
65+
66+
target := Address{FunctionType: TypeNameFrom("namespace/function1"), Id: "1"}
67+
68+
statefulFunction := func(ctx Context, message Message) error {
69+
assert.Nil(t, ctx.Caller())
70+
return nil
71+
}
72+
73+
err := invokeStatefulFunction(context.Background(), &target, nil, nil, StatefulFunctionPointer(statefulFunction))
74+
assert.Nil(t, err)
75+
}
76+
77+
func TestStatefunHandler_WithCaller_ContextCallerIsCorrect(t *testing.T) {
78+
79+
target := Address{FunctionType: TypeNameFrom("namespace/function1"), Id: "1"}
80+
caller := Address{FunctionType: TypeNameFrom("namespace/function2"), Id: "2"}
81+
82+
statefulFunction := func(ctx Context, message Message) error {
83+
assert.Equal(t, caller.String(), ctx.Caller().String())
84+
return nil
85+
}
86+
87+
err := invokeStatefulFunction(context.Background(), &target, &caller, nil, StatefulFunctionPointer(statefulFunction))
88+
assert.Nil(t, err)
89+
}

0 commit comments

Comments
 (0)