diff --git a/lambda/handler.go b/lambda/handler.go index e4cfaf7a..acdc7cac 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -28,6 +28,7 @@ type handlerOptions struct { jsonResponseIndentValue string enableSIGTERM bool sigtermCallbacks []func() + setupFuncs []func() error } type Option func(*handlerOptions) @@ -102,6 +103,15 @@ func WithEnableSIGTERM(callbacks ...func()) Option { }) } +// WithSetup enables capturing of errors or panics that occur before the function is ready to handle invokes. +// The provided functions will be run a single time, in order, before the runtime reports itself ready to recieve invokes. +// If any of the provided functions returns an error, or panics, the error will be serialized and reported to the Runtime API. +func WithSetup(funcs ...func() error) Option { + return Option(func(h *handlerOptions) { + h.setupFuncs = append(h.setupFuncs, funcs...) + }) +} + // handlerTakesContext returns whether the handler takes a context.Context as its first argument. func handlerTakesContext(handler reflect.Type) (bool, error) { switch handler.NumIn() { diff --git a/lambda/invoke_loop.go b/lambda/invoke_loop.go index 9e2d6598..b243eecc 100644 --- a/lambda/invoke_loop.go +++ b/lambda/invoke_loop.go @@ -31,6 +31,10 @@ func unixMS(ms int64) time.Time { func startRuntimeAPILoop(api string, handler Handler) error { client := newRuntimeAPIClient(api) h := newHandler(handler) + + if err := handleSetup(client, h); err != nil { + return err + } for { invoke, err := client.next() if err != nil { @@ -42,6 +46,21 @@ func startRuntimeAPILoop(api string, handler Handler) error { } } +// handleSetup returns an error if any of the handler's optional setup functions return and error or panic +func handleSetup(client *runtimeAPIClient, handler *handlerOptions) error { + for _, setup := range handler.setupFuncs { + if setupErr := callSetupFunc(setup); setupErr != nil { + errorPayload := safeMarshal(setupErr) + log.Printf("%s", errorPayload) + if err := client.initError(bytes.NewReader(errorPayload), contentTypeJSON); err != nil { + return fmt.Errorf("unexpected error occurred when sending the setup error to the API: %v", err) + } + return fmt.Errorf("setting up the handler function resulted in an error, the process should exit") + } + } + return nil +} + // handleInvoke returns an error if the function panics, or some other non-recoverable error occurred func handleInvoke(invoke *invoke, handler *handlerOptions) error { // set the deadline @@ -110,6 +129,18 @@ func reportFailure(invoke *invoke, invokeErr *messages.InvokeResponse_Error) err return nil } +func callSetupFunc(f func() error) (setupErr *messages.InvokeResponse_Error) { + defer func() { + if err := recover(); err != nil { + setupErr = lambdaPanicResponse(err) + } + }() + if err := f(); err != nil { + return lambdaErrorResponse(err) + } + return nil +} + func callBytesHandlerFunc(ctx context.Context, payload []byte, handler handlerFunc) (response io.Reader, invokeErr *messages.InvokeResponse_Error) { defer func() { if err := recover(); err != nil { diff --git a/lambda/rpc_function.go b/lambda/rpc_function.go index 0c8e798e..6f42cea5 100644 --- a/lambda/rpc_function.go +++ b/lambda/rpc_function.go @@ -33,11 +33,19 @@ func init() { } func startFunctionRPC(port string, handler Handler) error { + rpcFunction := NewFunction(handler) + if len(rpcFunction.handler.setupFuncs) > 0 { + runtimeAPIClient := newRuntimeAPIClient(os.Getenv("AWS_LAMBDA_RUNTIME_API")) + if err := handleSetup(runtimeAPIClient, rpcFunction.handler); err != nil { + return err + } + } + lis, err := net.Listen("tcp", "localhost:"+port) if err != nil { log.Fatal(err) } - err = rpc.Register(NewFunction(handler)) + err = rpc.Register(rpcFunction) if err != nil { log.Fatal("failed to register handler function") } diff --git a/lambda/runtime_api_client.go b/lambda/runtime_api_client.go index a83c3ce8..3d13df89 100644 --- a/lambda/runtime_api_client.go +++ b/lambda/runtime_api_client.go @@ -37,11 +37,18 @@ func newRuntimeAPIClient(address string) *runtimeAPIClient { client := &http.Client{ Timeout: 0, // connections to the runtime API are never expected to time out } - endpoint := "http://" + address + "/" + apiVersion + "/runtime/invocation/" + endpoint := "http://" + address + "/" + apiVersion + "/runtime" userAgent := "aws-lambda-go/" + runtime.Version() return &runtimeAPIClient{endpoint, userAgent, client, bytes.NewBuffer(nil)} } +// initError connects to the Runtime API and reports that a failure occured during initialization. +// Note: After calling this function, the caller should call os.Exit() +func (c *runtimeAPIClient) initError(body io.Reader, contentType string) error { + url := c.baseURL + "/init/error" + return c.post(url, body, contentType) +} + type invoke struct { id string payload []byte @@ -53,7 +60,7 @@ type invoke struct { // Notes: // - An invoke is not complete until next() is called again! func (i *invoke) success(body io.Reader, contentType string) error { - url := i.client.baseURL + i.id + "/response" + url := i.client.baseURL + "/invocation/" + i.id + "/response" return i.client.post(url, body, contentType) } @@ -63,14 +70,14 @@ func (i *invoke) success(body io.Reader, contentType string) error { // - A Lambda Function continues to be re-used for future invokes even after a failure. // If the error is fatal (panic, unrecoverable state), exit the process immediately after calling failure() func (i *invoke) failure(body io.Reader, contentType string) error { - url := i.client.baseURL + i.id + "/error" + url := i.client.baseURL + "/invocation/" + i.id + "/error" return i.client.post(url, body, contentType) } // next connects to the Runtime API and waits for a new invoke Request to be available. // Note: After a call to Done() or Error() has been made, a call to next() will complete the in-flight invoke. func (c *runtimeAPIClient) next() (*invoke, error) { - url := c.baseURL + "next" + url := c.baseURL + "/invocation/next" req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("failed to construct GET request to %s: %v", url, err)