Skip to content

Commit

Permalink
feat: Add support for strongly typed function signature (GoogleCloudP…
Browse files Browse the repository at this point in the history
  • Loading branch information
kappratiksha authored Feb 7, 2023
1 parent e038fee commit 06264b6
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 2 deletions.
67 changes: 67 additions & 0 deletions funcframework/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ const (
fnErrorMessageStderrTmpl = "Function error: %v"
)

var errorType = reflect.TypeOf((*error)(nil)).Elem()

// recoverPanic recovers from a panic in a consistent manner. panicSrc should
// describe what was happening when the panic was encountered, for example
// "user function execution". w is an http.ResponseWriter to write a generic
Expand Down Expand Up @@ -168,6 +170,12 @@ func wrapFunction(fn *registry.RegisteredFunction) (http.Handler, error) {
return nil, fmt.Errorf("unexpected error in wrapEventFunction: %v", err)
}
return handler, nil
} else if fn.TypedFn != nil {
handler, err := wrapTypedFunction(fn.TypedFn)
if err != nil {
return nil, fmt.Errorf("unexpected error in wrapTypedFunction: %v", err)
}
return handler, nil
}
return nil, fmt.Errorf("missing function entry in %v", fn)
}
Expand Down Expand Up @@ -206,6 +214,65 @@ func wrapEventFunction(fn interface{}) (http.Handler, error) {
}), nil
}

func wrapTypedFunction(fn interface{}) (http.Handler, error) {
inputType, err := validateTypedFunction(fn)
if err != nil {
return nil, err
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := readHTTPRequestBody(r)
if err != nil {
writeHTTPErrorResponse(w, http.StatusBadRequest, crashStatus, fmt.Sprintf("%v", err))
return
}
argVal := inputType

if err := json.Unmarshal(body, argVal.Interface()); err != nil {
writeHTTPErrorResponse(w, http.StatusBadRequest, crashStatus, fmt.Sprintf("Error while converting input data. %s", err.Error()))
return
}

defer recoverPanic(w, "user function execution")
funcReturn := reflect.ValueOf(fn).Call([]reflect.Value{
argVal.Elem(),
})

handleTypedReturn(w, funcReturn)
}), nil
}

func handleTypedReturn(w http.ResponseWriter, funcReturn []reflect.Value) {
if len(funcReturn) == 0 {
return
}
errorVal := funcReturn[len(funcReturn)-1].Interface() // last return must be of type error
if errorVal != nil && reflect.TypeOf(errorVal).AssignableTo(errorType) {
writeHTTPErrorResponse(w, http.StatusInternalServerError, errorStatus, fmtFunctionError(errorVal))
return
}

firstVal := funcReturn[0].Interface()
if !reflect.TypeOf(firstVal).AssignableTo(errorType) {
returnVal, _ := json.Marshal(firstVal)
fmt.Fprintf(w, string(returnVal))
}
}

func validateTypedFunction(fn interface{}) (*reflect.Value, error) {
ft := reflect.TypeOf(fn)
if ft.NumIn() != 1 {
return nil, fmt.Errorf("expected function to have one parameters, found %d", ft.NumIn())
}
if ft.NumOut() > 2 {
return nil, fmt.Errorf("expected function to have maximum two return values")
}
if ft.NumOut() > 0 && !ft.Out(ft.NumOut()-1).AssignableTo(errorType) {
return nil, fmt.Errorf("expected last return type to be of error")
}
var inputType = reflect.New(ft.In(0))
return &inputType, nil
}

func wrapCloudEventFunction(ctx context.Context, fn func(context.Context, cloudevents.Event) error) (http.Handler, error) {
p, err := cloudevents.NewHTTP()
if err != nil {
Expand Down
190 changes: 190 additions & 0 deletions funcframework/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,200 @@ type customStruct struct {
Name string `json:"name"`
}

type testStruct struct {
Age int
Name string
}

type eventData struct {
Data string `json:"data"`
}

func TestRegisterTypedFunction(t *testing.T) {
var tests = []struct {
name string
path string
body []byte
fn interface{}
target string
status int
header string
ceHeaders map[string]string
wantResp string
wantStderr string
}{
{
name: "TestTypedFunction_typed",
body: []byte(`{"id": 12345,"name": "custom"}`),
fn: func(s customStruct) (customStruct, error) {
return s, nil
},
status: http.StatusOK,
header: "",
wantResp: `{"id":12345,"name":"custom"}`,
},
{
name: "TestTypedFunction_no_return",
body: []byte(`{"id": 12345,"name": "custom"}`),
fn: func(s customStruct) {

},
status: http.StatusOK,
header: "",
wantResp: "",
},
{
name: "TestTypedFunction_untagged_struct",
body: []byte(`{"Age": 30,"Name": "john"}`),
fn: func(s testStruct) (testStruct, error) {
return s, nil
},
status: http.StatusOK,
header: "",
wantResp: `{"Age":30,"Name":"john"}`,
},
{
name: "TestTypedFunction_two_returns",
body: []byte(`{"id": 12345,"name": "custom"}`),
fn: func(s customStruct) (customStruct, error) {
return s, nil
},
status: http.StatusOK,
header: "",
wantResp: `{"id":12345,"name":"custom"}`,
},
{
name: "TestTypedFunction_return_int",
body: []byte(`{"id": 12345,"name": "custom"}`),
fn: func(s customStruct) (int, error) {
return s.ID, nil
},
status: http.StatusOK,
header: "",
wantResp: "12345",
},
{
name: "TestTypedFunction_different_types",
body: []byte(`{"id": 12345,"name": "custom"}`),
fn: func(s customStruct) (testStruct, error) {
var t = testStruct{99, "John"}
return t, nil
},
status: http.StatusOK,
header: "",
wantResp: `{"Age":99,"Name":"John"}`,
},
{
name: "TestTypedFunction_return_error",
body: []byte(`{"id": 12345,"name": "custom"}`),
fn: func(s customStruct) error {
return fmt.Errorf("Some error message")
},
status: http.StatusInternalServerError,
header: "error",
wantResp: fmt.Sprintf(fnErrorMessageStderrTmpl, "Some error message"),
wantStderr: "Some error message",
},
{
name: "TestTypedFunction_data_error",
body: []byte(`{"id": 12345,"name": 5}`),
fn: func(s customStruct) (customStruct, error) {
return s, nil
},
status: http.StatusBadRequest,
header: "crash",
wantStderr: "while converting input data",
},
{
name: "TestTypedFunction_func_error",
body: []byte(`{"id": 0,"name": "john"}`),
fn: func(s customStruct) (customStruct, error) {
s.ID = 10 / s.ID
return s, nil
},
status: http.StatusInternalServerError,
header: "crash",
wantStderr: "A panic occurred during user function execution",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
defer cleanup()
if len(tc.target) > 0 {
os.Setenv("FUNCTION_TARGET", tc.target)
}
functions.Typed(tc.name, tc.fn)
if _, ok := registry.Default().GetRegisteredFunction(tc.name); !ok {
t.Fatalf("could not get registered function: %s", tc.name)
}

origStderrPipe := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
defer func() { os.Stderr = origStderrPipe }()

server, err := initServer()
if err != nil {
t.Fatalf("initServer(): %v", err)
}
srv := httptest.NewServer(server)
defer srv.Close()

req, err := http.NewRequest("POST", srv.URL+"/"+tc.name, bytes.NewBuffer(tc.body))
if err != nil {
t.Fatalf("error creating HTTP request for test: %v", err)
}
for k, v := range tc.ceHeaders {
req.Header.Add(k, v)
}

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("client.Do(%s): %v", tc.name, err)
}

if err := w.Close(); err != nil {
t.Fatalf("failed to close stderr write pipe: %v", err)
}

stderr, err := ioutil.ReadAll(r)
if err != nil {
t.Errorf("failed to read stderr read pipe: %v", err)
}

if err := r.Close(); err != nil {
t.Fatalf("failed to close stderr read pipe: %v", err)
}

if !strings.Contains(string(stderr), tc.wantStderr) {
t.Errorf("stderr mismatch, got: %q, must contain: %q", string(stderr), tc.wantStderr)
}

if tc.wantStderr != "" && !strings.Contains(string(stderr), tc.wantStderr) {
t.Errorf("stderr mismatch, got: %q, must contain: %q", string(stderr), tc.wantStderr)
}

gotBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unable to read got request body: %v", err)
}

if tc.wantResp != "" && strings.TrimSpace(string(gotBody)) != tc.wantResp {
t.Errorf("TestTypedFunction(%s): response body = %q, want %q on error status code %d.", tc.name, gotBody, tc.wantResp, tc.status)
}

if resp.StatusCode != tc.status {
t.Errorf("TestTypedFunction(%s): response status = %v, want %v, %q.", tc.name, resp.StatusCode, tc.status, string(gotBody))
}
if resp.Header.Get(functionStatusHeader) != tc.header {
t.Errorf("TestTypedFunction(%s): response header = %q, want %q", tc.name, resp.Header.Get(functionStatusHeader), tc.header)
}
})
}
}

func TestRegisterEventFunctionContext(t *testing.T) {
var tests = []struct {
name string
Expand Down
10 changes: 10 additions & 0 deletions functions/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,13 @@ func CloudEvent(name string, fn func(context.Context, cloudevents.Event) error)
log.Fatalf("failure to register function: %s", err)
}
}

// Typed registers a Typed function that becomes the function handler
// served at "/" when environment variable `FUNCTION_TARGET=name`
// This function takes a strong type T as an input and can return a strong type T,
// built in types, nil and/or error as an output
func Typed(name string, fn interface{}) {
if err := registry.Default().RegisterTyped(fn, registry.WithName(name)); err != nil {
log.Fatalf("failure to register function: %s", err)
}
}
10 changes: 8 additions & 2 deletions internal/registry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type RegisteredFunction struct {
CloudEventFn func(context.Context, cloudevents.Event) error // Optional: The user's CloudEvent function
HTTPFn func(http.ResponseWriter, *http.Request) // Optional: The user's HTTP function
EventFn interface{} // Optional: The user's Event function
TypedFn interface{} // Optional: The user's typed function
}

// Option is an option used when registering a function.
Expand Down Expand Up @@ -62,16 +63,21 @@ func (r *Registry) RegisterHTTP(fn func(http.ResponseWriter, *http.Request), opt
return r.register(&RegisteredFunction{HTTPFn: fn}, options...)
}

// RegistryCloudEvent registers a CloudEvent function.
// RegisterCloudEvent registers a CloudEvent function.
func (r *Registry) RegisterCloudEvent(fn func(context.Context, cloudevents.Event) error, options ...Option) error {
return r.register(&RegisteredFunction{CloudEventFn: fn}, options...)
}

// RegistryCloudEvent registers a Event function.
// RegisterEvent registers an Event function.
func (r *Registry) RegisterEvent(fn interface{}, options ...Option) error {
return r.register(&RegisteredFunction{EventFn: fn}, options...)
}

// RegisterTyped registers a strongly typed function.
func (r *Registry) RegisterTyped(fn interface{}, options ...Option) error {
return r.register(&RegisteredFunction{TypedFn: fn}, options...)
}

func (r *Registry) register(function *RegisteredFunction, options ...Option) error {
for _, o := range options {
o(function)
Expand Down

0 comments on commit 06264b6

Please sign in to comment.