Skip to content

Commit

Permalink
Let haskell unary handler in coroutine ecosystem
Browse files Browse the repository at this point in the history
  • Loading branch information
4eUeP committed Aug 31, 2023
1 parent 05d9084 commit 3b2be4b
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 12 deletions.
48 changes: 44 additions & 4 deletions hs-grpc-server/HsGrpc/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ module HsGrpc.Server
, BidiStreamHandler
, ServiceHandler
, unary
, shortUnary
, clientStream
, serverStream
, bidiStream
Expand Down Expand Up @@ -141,12 +142,13 @@ runAsioGrpc server handlers onStarted maxBufferSize =
HF.withByteStringList (map rpcMethod handlers) $ \ms' ms_len' total_len ->
HF.withPrimList (map (handlerCStreamingType . rpcHandler) handlers) $ \mt' _mt_len ->
HF.withPrimList (map (fromBool . rpcUseThreadPool) handlers) $ \mUseThread' _ ->
HF.withPrimList (map (fromBool . isShortUnary . rpcHandler) handlers) $ \mIsShortUnary' _ ->
-- handlers callback
withProcessorCallback (processorCallback $ map rpcHandler handlers) $ \cbPtr -> do
evm <- getSystemEventManager' $ ServerException "failed to get event manager"
withFdEventNotification evm onStarted OneShot $ \(Fd cfdOnStarted) -> do
let start = run_asio_server server_ptr
ms' ms_len' mt' mUseThread' total_len
ms' ms_len' mt' mUseThread' mIsShortUnary' total_len
cbPtr
cfdOnStarted
(fromIntegral maxBufferSize)
Expand Down Expand Up @@ -202,6 +204,8 @@ type BidiStreamHandler i o a = ServerContext -> BidiStream i o -> IO a
data RpcHandler where
UnaryHandler
:: (Message i, Message o) => UnaryHandler i o -> RpcHandler
ShortUnaryHandler
:: (Message i, Message o) => UnaryHandler i o -> RpcHandler
ClientStreamHandler
:: (Message i, Message o) => ClientStreamHandler i o -> RpcHandler
ServerStreamHandler
Expand All @@ -211,12 +215,18 @@ data RpcHandler where

instance Show RpcHandler where
show (UnaryHandler _) = "<UnaryHandler>"
show (ShortUnaryHandler _) = "<ShortUnaryHandler>"
show (ClientStreamHandler _) = "<ClientStreamHandler>"
show (ServerStreamHandler _) = "<ServerStreamHandler>"
show (BidiStreamHandler _) = "<BidiStreamHandler>"

isShortUnary :: RpcHandler -> Bool
isShortUnary (ShortUnaryHandler _) = True
isShortUnary _ = False

handlerCStreamingType :: RpcHandler -> Word8
handlerCStreamingType (UnaryHandler _) = C_StreamingType_NonStreaming
handlerCStreamingType (ShortUnaryHandler _) = C_StreamingType_NonStreaming
handlerCStreamingType (ClientStreamHandler _) = C_StreamingType_ClientStreaming
handlerCStreamingType (ServerStreamHandler _) = C_StreamingType_ServerStreaming
handlerCStreamingType (BidiStreamHandler _) = C_StreamingType_BiDiStreaming
Expand Down Expand Up @@ -246,6 +256,20 @@ unary grpc handler =
, rpcUseThreadPool = False
}

shortUnary
:: ( HasMethod s m, Message i, Message o
, MethodInput s m ~ i
, MethodOutput s m ~ o
)
=> GRPC s m
-> UnaryHandler i o
-> ServiceHandler
shortUnary grpc handler =
ServiceHandler{ rpcMethod = getGrpcMethod grpc
, rpcHandler = ShortUnaryHandler handler
, rpcUseThreadPool = False
}

clientStream
:: ( HasMethod s m, Message i, Message o
, MethodInput s m ~ i
Expand Down Expand Up @@ -296,22 +320,38 @@ processorCallback handlers request_ptr response_ptr = do
-- the cpp side already makes the bound check
let handler = handlers !! requestHandlerIdx req -- TODO: use vector to gain O(1) access
case handler of
UnaryHandler hd -> unaryCallback req hd response_ptr
UnaryHandler hd -> void $ forkIO $ unaryCallback req hd response_ptr
ShortUnaryHandler hd -> shortUnaryCallback req hd response_ptr
ClientStreamHandler hd -> void $ forkIO $ clientStreamCallback req hd response_ptr
ServerStreamHandler hd -> void $ forkIO $ serverStreamCallback req hd response_ptr
BidiStreamHandler hd -> void $ forkIO $ bidiStreamCallback req hd response_ptr

unaryCallback
:: (Message i, Message o)
=> Request -> UnaryHandler i o -> Ptr Response -> IO ()
unaryCallback Request{..} hd response_ptr = catchGrpcError response_ptr $ do
unaryCallback Request{..} hd response_ptr =
void $ catchGrpcError' response_ptr clean action
where
action = do
let e_requestMsg = decodeMessage requestPayload
case e_requestMsg of
Left errmsg -> parsingReqErrReply response_ptr (BSC.pack errmsg)
Right requestMsg -> do
replyBs <- encodeMessage <$> hd requestServerContext requestMsg
poke response_ptr defResponse{responseData = Just replyBs}
clean = do
void $ releaseCoroLock requestCoroLock

shortUnaryCallback
:: (Message i, Message o)
=> Request -> UnaryHandler i o -> Ptr Response -> IO ()
shortUnaryCallback Request{..} hd response_ptr = catchGrpcError response_ptr $ do
let e_requestMsg = decodeMessage requestPayload
case e_requestMsg of
Left errmsg -> parsingReqErrReply response_ptr (BSC.pack errmsg)
Right requestMsg -> do
replyBs <- encodeMessage <$> hd requestServerContext requestMsg
poke response_ptr defResponse{responseData = Just replyBs}
{-# INLINABLE unaryCallback #-}

clientStreamCallback
:: (Message o)
Expand Down
2 changes: 2 additions & 0 deletions hs-grpc-server/HsGrpc/Server/FFI.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ foreign import ccall safe "run_asio_server"
-- ^ Value of method_handlers: StreamingType
-> Ptr CBool
-- ^ Value of method_handlers: use_thread_pool
-> Ptr CBool
-- ^ Value of method_handlers: is_short_unary
-> Int
-- ^ Total size of method_handlers
-> FunPtr ProcessorCallback
Expand Down
18 changes: 18 additions & 0 deletions hs-grpc-server/HsGrpc/Server/Types.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ module HsGrpc.Server.Types
, Request (..)
, Response (..)
, defResponse
, CoroLock
, releaseCoroLock
-- ** StreamingType
, StreamingType (..)
, pattern C_StreamingType_NonStreaming
Expand All @@ -83,7 +85,9 @@ import Data.Word (Word64, Word8)
import Foreign.Marshal.Alloc (allocaBytesAligned)
import Foreign.Ptr (FunPtr, Ptr, freeHaskellFunPtr,
nullPtr)
import Foreign.StablePtr (StablePtr)
import Foreign.Storable (Storable (..))
import GHC.Conc (PrimMVar)
import qualified HsForeign as HF

import HsGrpc.Common.Foreign.Channel
Expand Down Expand Up @@ -143,11 +147,23 @@ withProcessorCallback cb = bracket (mkProcessorCallback cb) freeHaskellFunPtr

-------------------------------------------------------------------------------

newtype CoroLock = CoroLock (Ptr ())
deriving (Show)

-- CoroLock should never be NULL. However, I recheck it inside the c function.
releaseCoroLock :: CoroLock -> IO Int
releaseCoroLock lock = HF.withPrimAsyncFFI @Int (release_corolock lock)

foreign import ccall unsafe "release_corolock"
release_corolock
:: CoroLock -> StablePtr PrimMVar -> Int -> Ptr Int -> IO ()

data Request = Request
{ requestPayload :: ByteString
, requestHandlerIdx :: Int
, requestReadChannel :: Maybe ChannelIn
, requestWriteChannel :: Maybe ChannelOut
, requestCoroLock :: CoroLock
, requestServerContext :: ServerContext
} deriving (Show)

Expand All @@ -170,11 +186,13 @@ instance Storable Request where
(#peek hsgrpc::server_request_t, channel_in) ptr
channelOut <- peekMaybeCppChannelOut =<<
(#peek hsgrpc::server_request_t, channel_out) ptr
coroLock <- (#peek hsgrpc::server_request_t, coro_lock) ptr
serverContext <- (#peek hsgrpc::server_request_t, server_context) ptr
return $ Request{ requestPayload = payload
, requestHandlerIdx = handleIdx
, requestReadChannel = channelIn
, requestWriteChannel = channelOut
, requestCoroLock = CoroLock coroLock
, requestServerContext = ServerContext serverContext
}
poke _ptr _req = error "Request is not pokeable"
Expand Down
47 changes: 39 additions & 8 deletions hs-grpc-server/cbits/hs_grpc_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ struct HandlerInfo {
// bool use_thread_pool : 1;
StreamingType type;
bool use_thread_pool;
bool is_short_unary;
};

struct HsAsioHandler {
Expand All @@ -169,7 +170,7 @@ struct HsAsioHandler {
handleUnary(grpc::GenericServerContext& server_context,
grpc::GenericServerAsyncReaderWriter& reader_writer,
server_request_t& request, server_response_t& response,
bool use_thread_pool) {
bool use_thread_pool, bool is_short_unary) {
// Wait for the request message
grpc::ByteBuffer buffer;
bool read_ok = co_await agrpc::read(reader_writer, buffer);
Expand All @@ -193,12 +194,24 @@ struct HsAsioHandler {
request.data_size = slice.size();
request.server_context = &server_context;

if (use_thread_pool) {
co_await asio::post(
asio::bind_executor(thread_pool, asio::use_awaitable));
if (is_short_unary) {
if (use_thread_pool) {
co_await asio::post(
asio::bind_executor(thread_pool, asio::use_awaitable));
}
// Call haskell handler
(*callback)(&request, &response);
} else {
// FIXME: use a lightweight structure instead (a real coroutine lock)
auto coro_lock = CoroLock(co_await asio::this_coro::executor, 1);
request.coro_lock = &coro_lock;

// Call haskell handler
(*callback)(&request, &response);

const auto [ec, _] =
co_await coro_lock.async_receive(asio::as_tuple(asio::use_awaitable));
}
// Call haskell handler
(*callback)(&request, &response);

// Return to client
auto status_code = static_cast<grpc::StatusCode>(response.status_code);
Expand Down Expand Up @@ -375,7 +388,8 @@ struct HsAsioHandler {
switch (method_handler_->second.type) {
case StreamingType::NonStreaming: {
co_await handleUnary(server_context, reader_writer, request, response,
method_handler_->second.use_thread_pool);
method_handler_->second.use_thread_pool,
method_handler_->second.is_short_unary);
break;
}
case StreamingType::BiDiStreaming: {
Expand Down Expand Up @@ -501,6 +515,7 @@ void run_asio_server(CppAsioServer* server,
char** method_handlers, HsInt* method_handlers_len,
uint8_t* method_handlers_type,
bool* method_handlers_use_thread_pool,
bool* method_handlers_is_short_unary,
HsInt method_handlers_total_len,
// method handlers end
hsgrpc::HsCallback callback, int fd_on_started,
Expand All @@ -510,7 +525,8 @@ void run_asio_server(CppAsioServer* server,
server->method_handlers_.emplace(std::make_pair(
std::string(method_handlers[i], method_handlers_len[i]),
hsgrpc::HandlerInfo{i, hsgrpc::StreamingType(method_handlers_type[i]),
method_handlers_use_thread_pool[i]}));
method_handlers_use_thread_pool[i],
method_handlers_is_short_unary[i]}));
}

auto parallelism = server->server_threads_.capacity();
Expand Down Expand Up @@ -543,6 +559,10 @@ void shutdown_asio_server(CppAsioServer* server) {

void delete_asio_server(CppAsioServer* server) {
gpr_log(GPR_DEBUG, "Delete allocated server");
// FIXME: The delete_asio_server function is invoked by the Haskell garbage
// collector. In the presence of a C++ thread(server->server_threads_) that
// hasn't been properly joined, an error of "terminate called without an
// active exception" can occur.
delete server;
#ifdef HSGRPC_ENABLE_ASAN
__lsan_do_leak_check();
Expand Down Expand Up @@ -604,5 +624,16 @@ void delete_out_channel(hsgrpc::channel_out_t* channel) {
delete channel;
}

void release_corolock(hsgrpc::CoroLock* channel, HsStablePtr mvar, HsInt cap,
HsInt* ret_code) {
if (channel) {
channel->async_send(asio::error_code{}, true,
[cap, mvar, ret_code](asio::error_code ec) {
*ret_code = (HsInt)ec.value();
hs_try_putmvar(cap, mvar);
});
}
}

// ----------------------------------------------------------------------------
} // End extern "C"
7 changes: 7 additions & 0 deletions hs-grpc-server/include/hs_grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ using ChannelIn =
using ChannelOut = asio::experimental::concurrent_channel<void(
asio::error_code, grpc::ByteBuffer)>;

// FIXME: use a lightweight structure instead (a real coroutine lock)
//
// Using bool for convenience, this can be any type actually.
using CoroLock =
asio::experimental::concurrent_channel<void(asio::error_code, bool)>;

struct channel_in_t {
std::shared_ptr<ChannelIn> rep;
};
Expand Down Expand Up @@ -52,6 +58,7 @@ struct server_request_t {
HsInt handler_idx;
channel_in_t* channel_in = nullptr;
channel_out_t* channel_out = nullptr;
CoroLock* coro_lock = nullptr;
grpc::GenericServerContext* server_context = nullptr;
};

Expand Down

0 comments on commit 3b2be4b

Please sign in to comment.