diff --git a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift index a23ef1cf..a3d98aa2 100644 --- a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift +++ b/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift @@ -19,13 +19,22 @@ import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import NIOPosix +import Synchronization // This functionality is designed for local testing hence being a #if DEBUG flag. + // For example: -// // try Lambda.withLocalServer { -// Lambda.run { (context: LambdaContext, event: String, callback: @escaping (Result) -> Void) in -// callback(.success("Hello, \(event)!")) +// try await LambdaRuntimeClient.withRuntimeClient( +// configuration: .init(ip: "127.0.0.1", port: 7000), +// eventLoop: self.eventLoop, +// logger: self.logger +// ) { runtimeClient in +// try await Lambda.runLoop( +// runtimeClient: runtimeClient, +// handler: handler, +// logger: self.logger +// ) // } // } extension Lambda { @@ -36,290 +45,383 @@ extension Lambda { /// - body: Code to run within the context of the mock server. Typically this would be a Lambda.run function call. /// /// - note: This API is designed strictly for local testing and is behind a DEBUG flag - static func withLocalServer( + static func withLocalServer( invocationEndpoint: String? = nil, - _ body: @escaping () async throws -> Value - ) async throws -> Value { - let server = LocalLambda.Server(invocationEndpoint: invocationEndpoint) - try await server.start().get() - defer { try! server.stop() } - return try await body() + _ body: @escaping () async throws -> Void + ) async throws { + + // launch the local server and wait for it to be started before running the body + try await withThrowingTaskGroup(of: Void.self) { group in + // this call will return when the server calls continuation.resume() + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + group.addTask { + try await LambdaHttpServer(invocationEndpoint: invocationEndpoint).start(continuation: continuation) + } + } + // now that server is started, run the Lambda function itself + try await body() + } } } -// MARK: - Local Mock Server - -private enum LocalLambda { - struct Server { - private let logger: Logger - private let group: EventLoopGroup - private let host: String - private let port: Int - private let invocationEndpoint: String - - init(invocationEndpoint: String?) { - var logger = Logger(label: "LocalLambdaServer") - logger.logLevel = .info - self.logger = logger - self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - self.host = "127.0.0.1" - self.port = 7000 - self.invocationEndpoint = invocationEndpoint ?? "/invoke" - } +// MARK: - Local HTTP Server + +/// An HTTP server that behaves like the AWS Lambda service for local testing. +/// This server is used to simulate the AWS Lambda service for local testing but also to accept invocation requests from the lambda client. +/// +/// It accepts three types of requests from the Lambda function (through the LambdaRuntimeClient): +/// 1. GET /next - the lambda function polls this endpoint to get the next invocation request +/// 2. POST /:requestID/response - the lambda function posts the response to the invocation request +/// 3. POST /:requestID/error - the lambda function posts an error response to the invocation request +/// +/// It also accepts one type of request from the client invoking the lambda function: +/// 1. POST /invoke - the client posts the event to the lambda function +/// +/// This server passes the data received from /invoke POST request to the lambda function (GET /next) and then forwards the response back to the client. +private struct LambdaHttpServer { + private let logger: Logger + private let group: EventLoopGroup + private let host: String + private let port: Int + private let invocationEndpoint: String + + private let invocationPool = Pool() + private let responsePool = Pool() + + init(invocationEndpoint: String?) { + var logger = Logger(label: "LocalServer") + logger.logLevel = Lambda.env("LOG_LEVEL").flatMap(Logger.Level.init) ?? .info + self.logger = logger + self.group = MultiThreadedEventLoopGroup.singleton + self.host = "127.0.0.1" + self.port = 7000 + self.invocationEndpoint = invocationEndpoint ?? "/invoke" + } - func start() -> EventLoopFuture { - let bootstrap = ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in - channel.pipeline.addHandler( - HTTPHandler(logger: self.logger, invocationEndpoint: self.invocationEndpoint) + func start(continuation: CheckedContinuation) async throws { + let channel = try await ServerBootstrap(group: self.group) + .serverChannelOption(.backlog, value: 256) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .childChannelOption(.maxMessagesPerRead, value: 1) + .bind( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withErrorHandling: true + ) + + return try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: NIOAsyncChannel.Configuration( + inboundType: HTTPServerRequestPart.self, + outboundType: HTTPServerResponsePart.self ) - } - } - return bootstrap.bind(host: self.host, port: self.port).flatMap { channel -> EventLoopFuture in - guard channel.localAddress != nil else { - return channel.eventLoop.makeFailedFuture(ServerError.cantBind) + ) } - self.logger.info( - "LocalLambdaServer started and listening on \(self.host):\(self.port), receiving events on \(self.invocationEndpoint)" - ) - return channel.eventLoop.makeSucceededFuture(()) } - } - func stop() throws { - try self.group.syncShutdownGracefully() + // notify the caller that the server is started + continuation.resume() + logger.info( + "Server started and listening", + metadata: [ + "host": "\(channel.channel.localAddress?.ipAddress?.debugDescription ?? "")", + "port": "\(channel.channel.localAddress?.port ?? 0)", + ] + ) + + // We are handling each incoming connection in a separate child task. It is important + // to use a discarding task group here which automatically discards finished child tasks. + // A normal task group retains all child tasks and their outputs in memory until they are + // consumed by iterating the group or by exiting the group. Since, we are never consuming + // the results of the group we need the group to automatically discard them; otherwise, this + // would result in a memory leak over time. + try await withThrowingDiscardingTaskGroup { group in + try await channel.executeThenClose { inbound in + for try await connectionChannel in inbound { + + group.addTask { + logger.trace("Handling a new connection") + await self.handleConnection(channel: connectionChannel) + logger.trace("Done handling the connection") + } + } + } } + logger.info("Server shutting down") } - final class HTTPHandler: ChannelInboundHandler { - public typealias InboundIn = HTTPServerRequestPart - public typealias OutboundOut = HTTPServerResponsePart - - private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>() - - private static var invocations = CircularBuffer() - private static var invocationState = InvocationState.waitingForLambdaRequest + /// This method handles individual TCP connections + private func handleConnection( + channel: NIOAsyncChannel + ) async { + + var requestHead: HTTPRequestHead! + var requestBody: ByteBuffer? + + // Note that this method is non-throwing and we are catching any error. + // We do this since we don't want to tear down the whole server when a single connection + // encounters an error. + do { + try await channel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + if case .head(let head) = inboundData { + requestHead = head + } + if case .body(let body) = inboundData { + requestBody = body + } + if case .end = inboundData { + precondition(requestHead != nil, "Received .end without .head") + // process the request + let response = try await self.processRequest( + head: requestHead, + body: requestBody + ) + // send the responses + try await self.sendResponse( + response: response, + outbound: outbound + ) - private let logger: Logger - private let invocationEndpoint: String + requestHead = nil + requestBody = nil + } + } + } + } catch { + logger.error("Hit error: \(error)") + } + } - init(logger: Logger, invocationEndpoint: String) { - self.logger = logger - self.invocationEndpoint = invocationEndpoint + /// This function process the URI request sent by the client and by the Lambda function + /// + /// It enqueues the client invocation and iterate over the invocation queue when the Lambda function sends /next request + /// It answers the /:requestID/response and /:requestID/error requests sent by the Lambda function but do not process the body + /// + /// - Parameters: + /// - head: the HTTP request head + /// - body: the HTTP request body + /// - Throws: + /// - Returns: the response to send back to the client or the Lambda function + private func processRequest(head: HTTPRequestHead, body: ByteBuffer?) async throws -> LocalServerResponse { + + if let body { + self.logger.trace( + "Processing request", + metadata: ["URI": "\(head.method) \(head.uri)", "Body": "\(String(buffer: body))"] + ) + } else { + self.logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"]) } - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let requestPart = unwrapInboundIn(data) + switch (head.method, head.uri) { - switch requestPart { - case .head(let head): - self.pending.append((head: head, body: nil)) - case .body(var buffer): - var request = self.pending.removeFirst() - if request.body == nil { - request.body = buffer + // + // client invocations + // + // client POST /invoke + case (.POST, let url) where url.hasSuffix(self.invocationEndpoint): + guard let body else { + return .init(status: .badRequest, headers: [], body: nil) + } + // we always accept the /invoke request and push them to the pool + let requestId = "\(DispatchTime.now().uptimeNanoseconds)" + logger.trace("/invoke received invocation", metadata: ["requestId": "\(requestId)"]) + await self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body)) + + // wait for the lambda function to process the request + for try await response in self.responsePool { + logger.trace( + "Received response to return to client", + metadata: ["requestId": "\(response.requestId ?? "")"] + ) + if response.requestId == requestId { + return response } else { - request.body!.writeBuffer(&buffer) + logger.error( + "Received response for a different request id", + metadata: ["response requestId": "\(response.requestId ?? "")", "requestId": "\(requestId)"] + ) + // should we return an error here ? Or crash as this is probably a programming error? } - self.pending.prepend(request) - case .end: - let request = self.pending.removeFirst() - self.processRequest(context: context, request: request) } - } - - func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) { - - let eventLoop = context.eventLoop - let loopBoundContext = NIOLoopBound(context, eventLoop: eventLoop) + // What todo when there is no more responses to process? + // This should not happen as the async iterator blocks until there is a response to process + fatalError("No more responses to process - the async for loop should not return") + + // client uses incorrect HTTP method + case (_, let url) where url.hasSuffix(self.invocationEndpoint): + return .init(status: .methodNotAllowed) + + // + // lambda invocations + // + + // /next endpoint is called by the lambda polling for work + // this call only returns when there is a task to give to the lambda function + case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix): + + // pop the tasks from the queue + self.logger.trace("/next waiting for /invoke") + for try await invocation in self.invocationPool { + self.logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"]) + // this call also stores the invocation requestId into the response + return invocation.makeResponse(status: .accepted) + } + // What todo when there is no more tasks to process? + // This should not happen as the async iterator blocks until there is a task to process + fatalError("No more invocations to process - the async for loop should not return") + + // :requestID/response endpoint is called by the lambda posting the response + case (.POST, let url) where url.hasSuffix(Consts.postResponseURLSuffix): + let parts = head.uri.split(separator: "/") + guard let requestId = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { + // the request is malformed, since we were expecting a requestId in the path + return .init(status: .badRequest) + } + // enqueue the lambda function response to be served as response to the client /invoke + logger.trace("/:requestID/response received response", metadata: ["requestId": "\(requestId)"]) + await self.responsePool.push( + LocalServerResponse( + id: requestId, + status: .ok, + headers: [("Content-Type", "application/json")], + body: body + ) + ) - switch (request.head.method, request.head.uri) { - // this endpoint is called by the client invoking the lambda - case (.POST, let url) where url.hasSuffix(self.invocationEndpoint): - guard let work = request.body else { - return self.writeResponse(context: context, response: .init(status: .badRequest)) - } - let requestID = "\(DispatchTime.now().uptimeNanoseconds)" // FIXME: - let promise = context.eventLoop.makePromise(of: Response.self) - promise.futureResult.whenComplete { result in - let context = loopBoundContext.value - switch result { - case .failure(let error): - self.logger.error("invocation error: \(error)") - self.writeResponse(context: context, response: .init(status: .internalServerError)) - case .success(let response): - self.writeResponse(context: context, response: response) - } - } - let invocation = Invocation(requestID: requestID, request: work, responsePromise: promise) - switch Self.invocationState { - case .waitingForInvocation(let promise): - promise.succeed(invocation) - case .waitingForLambdaRequest, .waitingForLambdaResponse: - Self.invocations.append(invocation) - } + // tell the Lambda function we accepted the response + return .init(id: requestId, status: .accepted) - // lambda invocation using the wrong http method - case (_, let url) where url.hasSuffix(self.invocationEndpoint): - self.writeResponse(context: context, status: .methodNotAllowed) - - // /next endpoint is called by the lambda polling for work - case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix): - // check if our server is in the correct state - guard case .waitingForLambdaRequest = Self.invocationState else { - self.logger.error("invalid invocation state \(Self.invocationState)") - self.writeResponse(context: context, response: .init(status: .unprocessableEntity)) - return - } + // :requestID/error endpoint is called by the lambda posting an error response + // we accept all requestID and we do not handle the body, we just acknowledge the request + case (.POST, let url) where url.hasSuffix(Consts.postErrorURLSuffix): + let parts = head.uri.split(separator: "/") + guard let _ = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { + // the request is malformed, since we were expecting a requestId in the path + return .init(status: .badRequest) + } + return .init(status: .ok) - // pop the first task from the queue - switch Self.invocations.popFirst() { - case .none: - // if there is nothing in the queue, - // create a promise that we can fullfill when we get a new task - let promise = context.eventLoop.makePromise(of: Invocation.self) - promise.futureResult.whenComplete { result in - let context = loopBoundContext.value - switch result { - case .failure(let error): - self.logger.error("invocation error: \(error)") - self.writeResponse(context: context, status: .internalServerError) - case .success(let invocation): - Self.invocationState = .waitingForLambdaResponse(invocation) - self.writeResponse(context: context, response: invocation.makeResponse()) - } - } - Self.invocationState = .waitingForInvocation(promise) - case .some(let invocation): - // if there is a task pending, we can immediately respond with it. - Self.invocationState = .waitingForLambdaResponse(invocation) - self.writeResponse(context: context, response: invocation.makeResponse()) - } + // unknown call + default: + return .init(status: .notFound) + } + } - // :requestID/response endpoint is called by the lambda posting the response - case (.POST, let url) where url.hasSuffix(Consts.postResponseURLSuffix): - let parts = request.head.uri.split(separator: "/") - guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { - // the request is malformed, since we were expecting a requestId in the path - return self.writeResponse(context: context, status: .badRequest) - } - guard case .waitingForLambdaResponse(let invocation) = Self.invocationState else { - // a response was send, but we did not expect to receive one - self.logger.error("invalid invocation state \(Self.invocationState)") - return self.writeResponse(context: context, status: .unprocessableEntity) - } - guard requestID == invocation.requestID else { - // the request's requestId is not matching the one we are expecting - self.logger.error( - "invalid invocation state request ID \(requestID) does not match expected \(invocation.requestID)" - ) - return self.writeResponse(context: context, status: .badRequest) - } + private func sendResponse( + response: LocalServerResponse, + outbound: NIOAsyncChannelOutboundWriter + ) async throws { + var headers = HTTPHeaders(response.headers ?? []) + headers.add(name: "Content-Length", value: "\(response.body?.readableBytes ?? 0)") + + self.logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"]) + try await outbound.write( + HTTPServerResponsePart.head( + HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: response.status, + headers: headers + ) + ) + ) + if let body = response.body { + try await outbound.write(HTTPServerResponsePart.body(.byteBuffer(body))) + } - invocation.responsePromise.succeed(.init(status: .ok, body: request.body)) - self.writeResponse(context: context, status: .accepted) - Self.invocationState = .waitingForLambdaRequest + try await outbound.write(HTTPServerResponsePart.end(nil)) + } - // :requestID/error endpoint is called by the lambda posting an error response - case (.POST, let url) where url.hasSuffix(Consts.postErrorURLSuffix): - let parts = request.head.uri.split(separator: "/") - guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { - // the request is malformed, since we were expecting a requestId in the path - return self.writeResponse(context: context, status: .badRequest) - } - guard case .waitingForLambdaResponse(let invocation) = Self.invocationState else { - // a response was send, but we did not expect to receive one - self.logger.error("invalid invocation state \(Self.invocationState)") - return self.writeResponse(context: context, status: .unprocessableEntity) - } - guard requestID == invocation.requestID else { - // the request's requestId is not matching the one we are expecting - self.logger.error( - "invalid invocation state request ID \(requestID) does not match expected \(invocation.requestID)" - ) - return self.writeResponse(context: context, status: .badRequest) - } + /// A shared data structure to store the current invocation or response requests and the continuation objects. + /// This data structure is shared between instances of the HTTPHandler + /// (one instance to serve requests from the Lambda function and one instance to serve requests from the client invoking the lambda function). + private final class Pool: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable { + typealias Element = T - invocation.responsePromise.succeed(.init(status: .internalServerError, body: request.body)) - self.writeResponse(context: context, status: .accepted) - Self.invocationState = .waitingForLambdaRequest + private let _buffer = Mutex>(.init()) + private let _continuation = Mutex?>(nil) - // unknown call - default: - self.writeResponse(context: context, status: .notFound) - } + /// retrieve the first element from the buffer + public func popFirst() async -> T? { + self._buffer.withLock { $0.popFirst() } } - func writeResponse(context: ChannelHandlerContext, status: HTTPResponseStatus) { - self.writeResponse(context: context, response: .init(status: status)) + /// enqueue an element, or give it back immediately to the iterator if it is waiting for an element + public func push(_ invocation: T) async { + // if the iterator is waiting for an element, give it to it + // otherwise, enqueue the element + if let continuation = self._continuation.withLock({ $0 }) { + self._continuation.withLock { $0 = nil } + continuation.resume(returning: invocation) + } else { + self._buffer.withLock { $0.append(invocation) } + } } - func writeResponse(context: ChannelHandlerContext, response: Response) { - var headers = HTTPHeaders(response.headers ?? []) - headers.add(name: "content-length", value: "\(response.body?.readableBytes ?? 0)") - let head = HTTPResponseHead( - version: HTTPVersion(major: 1, minor: 1), - status: response.status, - headers: headers - ) + func next() async throws -> T? { - context.write(wrapOutboundOut(.head(head))).whenFailure { error in - self.logger.error("\(self) write error \(error)") - } - - if let buffer = response.body { - context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))).whenFailure { error in - self.logger.error("\(self) write error \(error)") - } + // exit the async for loop if the task is cancelled + guard !Task.isCancelled else { + return nil } - context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in - if case .failure(let error) = result { - self.logger.error("\(self) write error \(error)") + if let element = await self.popFirst() { + return element + } else { + // we can't return nil if there is nothing to dequeue otherwise the async for loop will stop + // wait for an element to be enqueued + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + // store the continuation for later, when an element is enqueued + self._continuation.withLock { + $0 = continuation + } } } } - struct Response { - var status: HTTPResponseStatus = .ok - var headers: [(String, String)]? - var body: ByteBuffer? - } - - struct Invocation { - let requestID: String - let request: ByteBuffer - let responsePromise: EventLoopPromise - - func makeResponse() -> Response { - var response = Response() - response.body = self.request - // required headers - response.headers = [ - (AmazonHeaders.requestID, self.requestID), - ( - AmazonHeaders.invokedFunctionARN, - "arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... Int16.max)):function:custom-runtime" - ), - (AmazonHeaders.traceID, "Root=\(AmazonHeaders.generateXRayTraceID());Sampled=1"), - (AmazonHeaders.deadline, "\(DispatchWallTime.distantFuture.millisSinceEpoch)"), - ] - return response - } + func makeAsyncIterator() -> Pool { + self } + } - enum InvocationState { - case waitingForInvocation(EventLoopPromise) - case waitingForLambdaRequest - case waitingForLambdaResponse(Invocation) + private struct LocalServerResponse: Sendable { + let requestId: String? + let status: HTTPResponseStatus + let headers: [(String, String)]? + let body: ByteBuffer? + init(id: String? = nil, status: HTTPResponseStatus, headers: [(String, String)]? = nil, body: ByteBuffer? = nil) + { + self.requestId = id + self.status = status + self.headers = headers + self.body = body } } - enum ServerError: Error { - case notReady - case cantBind + private struct LocalServerInvocation: Sendable { + let requestId: String + let request: ByteBuffer + + func makeResponse(status: HTTPResponseStatus) -> LocalServerResponse { + + // required headers + let headers = [ + (AmazonHeaders.requestID, self.requestId), + ( + AmazonHeaders.invokedFunctionARN, + "arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... Int16.max)):function:custom-runtime" + ), + (AmazonHeaders.traceID, "Root=\(AmazonHeaders.generateXRayTraceID());Sampled=1"), + (AmazonHeaders.deadline, "\(DispatchWallTime.distantFuture.millisSinceEpoch)"), + ] + + return LocalServerResponse(id: self.requestId, status: status, headers: headers, body: self.request) + } } } #endif