Skip to content

Commit

Permalink
Use grpc::Alarm for unary coroutine lock instead of concurrent_channel
Browse files Browse the repository at this point in the history
  • Loading branch information
4eUeP committed Sep 1, 2023
1 parent 627ccfc commit e1975dc
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 52 deletions.
12 changes: 6 additions & 6 deletions example/app/simple-server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@

module Main where

import Control.Concurrent (threadDelay)
import Data.Either (isRight)
import Data.ProtoLens (defMessage)
import qualified Data.Text as Text
import Control.Concurrent (threadDelay)
import Data.Either (isRight)
import Data.ProtoLens (defMessage)
import qualified Data.Text as Text
import Lens.Micro

import HsGrpc.Common.Log
import HsGrpc.Server
import HsGrpc.Server.Context
import Proto.Example as P
import Proto.Example_Fields as P
import Proto.Example as P
import Proto.Example_Fields as P

handlers :: [ServiceHandler]
handlers =
Expand Down
43 changes: 33 additions & 10 deletions hs-grpc-server/HsGrpc/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ import Data.Word (Word8)
import Foreign.ForeignPtr (ForeignPtr, newForeignPtr,
withForeignPtr)
import Foreign.Marshal.Utils (fromBool)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Ptr (Ptr, nullFunPtr, nullPtr)
import Foreign.Storable (peek, poke)
import GHC.TypeLits (Symbol, symbolVal)
import qualified HsForeign as HF
Expand Down Expand Up @@ -97,7 +97,10 @@ runServer ServerOptions{..} handlers = do
server <- newAsioServer
serverHost serverPort serverParallelism
serverSslOptions serverInterceptors
runAsioGrpc server handlers serverOnStarted serverInternalChannelSize
runAsioGrpc server
handlers
serverOnStarted
serverInternalChannelSize serverMaxUnaryTime

-------------------------------------------------------------------------------

Expand Down Expand Up @@ -130,8 +133,10 @@ runAsioGrpc
-> Word
-- ^ This is the buffer size of 'CppChannelIn' or 'CppChannelOut' used for
-- streaming rpcs.
-> Int
-- ^ MaxUnaryTime
-> IO ()
runAsioGrpc server handlers onStarted maxBufferSize =
runAsioGrpc server handlers onStarted maxBufferSize maxUnaryTime =
withForeignPtr server $ \server_ptr ->
-- handlers info
HF.withByteStringList (map rpcMethod handlers) $ \ms' ms_len' total_len ->
Expand All @@ -147,6 +152,7 @@ runAsioGrpc server handlers onStarted maxBufferSize =
cbPtr
cfdOnStarted
(fromIntegral maxBufferSize)
maxUnaryTime
stop a = shutdown_asio_server server_ptr >> Async.wait a
in Ex.bracket (Async.async start) stop Async.wait

Expand Down Expand Up @@ -308,11 +314,21 @@ processorCallback handlers request_ptr response_ptr = do
-- the cpp side already makes the bound check
let handler = handlers !! requestHandlerIdx req -- TODO: use vector to gain O(1) access
case handler of
UnaryHandler hd -> void $ forkIO $ unaryCallback req hd response_ptr
ShortUnaryHandler hd -> shortUnaryCallback req hd response_ptr
ClientStreamHandler hd -> void $ forkIO $ clientStreamCallback req hd response_ptr
ServerStreamHandler hd -> void $ forkIO $ serverStreamCallback req hd response_ptr
BidiStreamHandler hd -> void $ forkIO $ bidiStreamCallback req hd response_ptr
UnaryHandler hd -> do
tid <- forkIO $ unaryCallback req hd response_ptr
mkIOFun $ Ex.uninterruptibleMask_ (Ex.throwTo tid ActiveCancelled) -- free at cpp side
ShortUnaryHandler hd -> do
shortUnaryCallback req hd response_ptr
pure nullFunPtr
ClientStreamHandler hd -> do
void $ forkIO $ clientStreamCallback req hd response_ptr
pure nullFunPtr
ServerStreamHandler hd -> do
void $ forkIO $ serverStreamCallback req hd response_ptr
pure nullFunPtr
BidiStreamHandler hd -> do
void $ forkIO $ bidiStreamCallback req hd response_ptr
pure nullFunPtr

unaryCallback
:: (Message i, Message o)
Expand All @@ -327,8 +343,7 @@ unaryCallback Request{..} hd response_ptr =
Right requestMsg -> do
replyBs <- encodeMessage <$> hd requestServerContext requestMsg
poke response_ptr defResponse{responseData = Just replyBs}
clean = do
void $ releaseCoroLock requestCoroLock
clean = void $ releaseCoroLock requestCoroLock

shortUnaryCallback
:: (Message i, Message o)
Expand Down Expand Up @@ -393,6 +408,13 @@ bidiStreamCallback req hd response_ptr =

-------------------------------------------------------------------------------

data ActiveCancelled = ActiveCancelled
deriving (Show, Eq)

instance Ex.Exception ActiveCancelled where
fromException = Ex.asyncExceptionFromException
toException = Ex.asyncExceptionToException

catchGrpcError :: Ptr Response -> IO () -> IO ()
catchGrpcError ptr = void . catchGrpcError' ptr (pure ())

Expand All @@ -403,6 +425,7 @@ catchGrpcError' ptr sequel action =
errReply ptr status
_ <- sequel
pure Nothing
, Ex.Handler $ \(_ :: ActiveCancelled) -> pure Nothing
-- NOTE: SomeException should be the last Handler
, Ex.Handler $ \(ex :: Ex.SomeException) -> do
someExReply ptr ex
Expand Down
1 change: 1 addition & 0 deletions hs-grpc-server/HsGrpc/Server/FFI.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ foreign import ccall safe "run_asio_server"
-> FunPtr ProcessorCallback
-> CInt -- ^ fd onStarted
-> CSize -- ^ Max buffer size for the internal streaming channel
-> Int -- ^ Max time of unary
-> IO ()
23 changes: 13 additions & 10 deletions hs-grpc-server/HsGrpc/Server/Types.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ module HsGrpc.Server.Types
, ProcessorCallback
, mkProcessorCallback
, withProcessorCallback
, mkIOFun
) where

import Control.Exception (Exception, bracket, throwIO)
Expand All @@ -85,9 +86,7 @@ import Data.Word (Word64, Word8)
import Foreign.Marshal.Alloc (allocaBytesAligned)
import Foreign.Ptr (FunPtr, Ptr, freeHaskellFunPtr,
nullPtr)
import Foreign.StablePtr (StablePtr)
import Foreign.Storable (Storable (..))
import GHC.Conc (PrimMVar)
import qualified HsForeign as HF

import HsGrpc.Common.Foreign.Channel
Expand All @@ -105,7 +104,10 @@ data ServerOptions = ServerOptions
, serverOnStarted :: !(Maybe (IO ()))
, serverInterceptors :: ![ServerInterceptor]
-- The following options are considering as internal
, serverInternalChannelSize :: !Word
, serverInternalChannelSize :: {-# UNPACK #-} !Word
, serverMaxUnaryTime :: {-# UNPACK #-} !Int
-- ^ Milliseconds, unary that take more than this time will return a
-- StatusDeadlineExceeded to the client.
}

defaultServerOpts :: ServerOptions
Expand All @@ -117,6 +119,7 @@ defaultServerOpts = ServerOptions
, serverOnStarted = Nothing
, serverInterceptors = []
, serverInternalChannelSize = 2
, serverMaxUnaryTime = 30 * 1000
}

instance Show ServerOptions where
Expand All @@ -135,7 +138,12 @@ instance Exception ServerException

-------------------------------------------------------------------------------

type ProcessorCallback = Ptr Request -> Ptr Response -> IO ()
type ProcessorCallback
= Ptr Request
-> Ptr Response
-> IO (FunPtr (IO ())) -- ^ An optional function that let C++ to call

foreign import ccall "wrapper" mkIOFun :: IO () -> IO (FunPtr (IO ()))

foreign import ccall "wrapper"
mkProcessorCallback :: ProcessorCallback -> IO (FunPtr ProcessorCallback)
Expand All @@ -150,13 +158,8 @@ withProcessorCallback cb = bracket (mkProcessorCallback cb) freeHaskellFunPtr
newtype CoroLock = CoroLock (Ptr ())
deriving (Show)

-- CoroLock should never be NULL. However, I recheck it inside the c function.
releaseCoroLock :: CoroLock -> IO Int
releaseCoroLock lock = HF.withPrimAsyncFFI @Int (release_corolock lock)

foreign import ccall unsafe "release_corolock"
release_corolock
:: CoroLock -> StablePtr PrimMVar -> Int -> Ptr Int -> IO ()
releaseCoroLock :: CoroLock -> IO ()

data Request = Request
{ requestPayload :: ByteString
Expand Down
44 changes: 23 additions & 21 deletions hs-grpc-server/cbits/hs_grpc_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ struct HsAsioHandler {
HsCallback& callback;
asio::thread_pool& thread_pool;
size_t max_buffer_size;
HsInt max_unary_time;

asio::awaitable<void>
handleUnary(grpc::GenericServerContext& server_context,
Expand Down Expand Up @@ -215,15 +216,21 @@ struct HsAsioHandler {
// Call haskell handler
(*callback)(&request, &response);
} else {
// FIXME: use a lightweight structure instead (a real coroutine lock)
auto coro_lock = CoroLock(co_await asio::this_coro::executor, 1);
grpc::Alarm coro_lock;
request.coro_lock = &coro_lock;

// Call haskell handler
(*callback)(&request, &response);

const auto [ec, _] =
co_await coro_lock.async_receive(asio::as_tuple(asio::use_awaitable));
auto after_cb = (*callback)(&request, &response);
bool deadline_exceeded = co_await agrpc::wait(
coro_lock, std::chrono::system_clock::now() +
std::chrono::milliseconds(max_unary_time));
if (deadline_exceeded) {
if (after_cb) {
(*after_cb)();
}
co_await agrpc::finish(reader_writer,
grpc::Status(grpc::StatusCode::DEADLINE_EXCEEDED,
"Unary deadline exceeded!"));
co_return;
}
}

// Return to client
Expand Down Expand Up @@ -558,7 +565,7 @@ void run_asio_server(CppAsioServer* server,
HsInt method_handlers_total_len,
// method handlers end
hsgrpc::HsCallback callback, int fd_on_started,
size_t max_buffer_size) {
size_t max_buffer_size, HsInt max_unary_time) {
server->method_handlers_.reserve(method_handlers_total_len);
for (HsInt i = 0; i < method_handlers_total_len; ++i) {
server->method_handlers_.emplace(std::make_pair(
Expand All @@ -576,10 +583,10 @@ void run_asio_server(CppAsioServer* server,
auto& grpc_context = *std::next(server->grpc_contexts_.begin(), i);
agrpc::repeatedly_request(
server->service_,
asio::bind_executor(grpc_context,
hsgrpc::HsAsioHandler{server->method_handlers_,
callback, thread_pool,
max_buffer_size}));
asio::bind_executor(
grpc_context, hsgrpc::HsAsioHandler{
server->method_handlers_, callback, thread_pool,
max_buffer_size, max_unary_time}));
grpc_context.run();
});
}
Expand Down Expand Up @@ -663,14 +670,9 @@ void delete_out_channel(hsgrpc::channel_out_t* channel) {
delete channel;
}

void release_corolock(hsgrpc::CoroLock* channel, HsStablePtr mvar, HsInt cap,
HsInt* ret_code) {
if (channel) {
channel->async_send(asio::error_code{}, true,
[cap, mvar, ret_code](asio::error_code ec) {
*ret_code = (HsInt)ec.value();
hs_try_putmvar(cap, mvar);
});
void release_corolock(hsgrpc::CoroLock* corolock) {
if (corolock) {
corolock->Cancel();
}
}

Expand Down
9 changes: 4 additions & 5 deletions hs-grpc-server/include/hs_grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <asio/experimental/concurrent_channel.hpp>
#include <cstdint>
#include <grpcpp/alarm.h>
#include <grpcpp/server.h>
#include <grpcpp/support/slice.h>

Expand All @@ -16,10 +17,7 @@ using ChannelOut = asio::experimental::concurrent_channel<void(
asio::error_code, grpc::ByteBuffer)>;

// FIXME: use a lightweight structure instead (a real coroutine lock)
//
// Using bool for convenience, this can be any type actually.
using CoroLock =
asio::experimental::concurrent_channel<void(asio::error_code, bool)>;
using CoroLock = grpc::Alarm;

struct channel_in_t {
std::shared_ptr<ChannelIn> rep;
Expand Down Expand Up @@ -70,7 +68,8 @@ struct server_response_t {
std::string* error_details = nullptr;
};

using HsCallback = void (*)(server_request_t*, server_response_t*);
using HsAfterCallback = void (*)(void);
using HsCallback = HsAfterCallback (*)(server_request_t*, server_response_t*);

struct read_channel_cb_data_t {
HsInt ec;
Expand Down

0 comments on commit e1975dc

Please sign in to comment.