Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connection pool per node #40

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions benchmark/ClusterBenchmark.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import Data.Time
import Database.Redis hiding (append)
import Text.Printf
import qualified Data.ByteString.Char8 as BS
import System.Random (randomRIO)

nRequests, nClients :: Int
nRequests = 10000
nClients = 50
nRequests = 30000
nClients = 150


clusterBenchMark :: IO ()
Expand Down Expand Up @@ -42,16 +43,16 @@ clusterBenchMark = do
start <- newEmptyMVar
done <- newEmptyMVar
replicateM_ nClients $ forkIO $ do
runRedis conn $ forever $ do
action <- liftIO $ takeMVar start
action
liftIO $ putMVar done ()
forever $ do
(reps,action) <- liftIO $ takeMVar start
replicateM_ reps $ runRedis conn action
liftIO $ putMVar done ()

let timeAction name nActions action = do
startT <- getCurrentTime
-- each clients runs ACTION nRepetitions times
let nRepetitions = nRequests `div` nClients `div` nActions
replicateM_ nClients $ putMVar start (replicateM_ nRepetitions action)
replicateM_ nClients $ putMVar start (nRepetitions,action)
replicateM_ nClients $ takeMVar done
stopT <- getCurrentTime
let deltaT = realToFrac $ diffUTCTime stopT startT
Expand Down Expand Up @@ -83,6 +84,13 @@ clusterBenchMark = do
Right Nothing -> return ()
_ -> error "error"
return ()

timeAction "get random keys" 1 $ do
key <- randomRIO (0::Int, 16000)
get (BS.pack (show key)) >>= \case
Right Nothing -> return ()
_ -> error "error"
return ()

timeAction "get pipelined 10" 10 $ do
res <- replicateM 10 (get "k1")
Expand Down
6 changes: 4 additions & 2 deletions hedis.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ library
errors,
network-uri,
unliftio-core,
random
random,
extra
if !impl(ghc >= 8.0)
build-depends:
semigroups >= 0.11 && < 0.19
Expand Down Expand Up @@ -122,7 +123,8 @@ benchmark hedis-benchmark
mtl >= 2.0,
hedis,
bytestring,
time >= 1.2
time >= 1.2,
random
other-modules: ClusterBenchmark
ghc-options: -O2 -Wall -rtsopts
if flag(dev)
Expand Down
324 changes: 160 additions & 164 deletions src/Database/Redis/Cluster.hs

Large diffs are not rendered by default.

165 changes: 82 additions & 83 deletions src/Database/Redis/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

{-# LANGUAGE NamedFieldPuns #-}
module Database.Redis.Connection where

import Control.Exception
import qualified Control.Monad.Catch as Catch
import Control.Monad.IO.Class(liftIO, MonadIO)
import Control.Monad(when)
import Control.Concurrent.MVar(MVar, newMVar, readMVar, modifyMVar_)
import Control.Monad(when,foldM)

import Control.Concurrent.MVar(modifyMVar)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as Char8
import Data.Functor(void)
Expand All @@ -29,9 +30,10 @@ import Text.Read (readMaybe)
import qualified Database.Redis.ProtocolPipelining as PP
import Database.Redis.Core(Redis, runRedisInternal, runRedisClusteredInternal)
import Database.Redis.Protocol(Reply(..))
import Database.Redis.Cluster(ShardMap(..), Node(..), Shard(..), NodeID, NodeConnection)
import Database.Redis.Cluster(ShardMap(..), Node(..), Shard(..))
import qualified Database.Redis.Cluster as Cluster
import qualified Database.Redis.ConnectionContext as CC
import qualified System.Timeout as T

import Database.Redis.Commands
( ping
Expand All @@ -43,7 +45,6 @@ import Database.Redis.Commands
, ClusterSlotsResponse(..)
, ClusterSlotsResponseEntry(..)
, ClusterSlotsNode(..))
import qualified System.Timeout as T

--------------------------------------------------------------------------------
-- Connection
Expand All @@ -53,7 +54,7 @@ import qualified System.Timeout as T
-- 'connect' function to create one.
data Connection
= NonClusteredConnection (Pool PP.Connection)
| ClusteredConnection (MVar ShardMap) (Pool Cluster.Connection) ConnectInfo
| ClusteredConnection ConnectInfo Cluster.Connection

-- |Information for connnecting to a Redis server.
--
Expand Down Expand Up @@ -185,7 +186,7 @@ checkedConnect connInfo = do
-- |Destroy all idle resources in the pool.
disconnect :: Connection -> IO ()
disconnect (NonClusteredConnection pool) = destroyAllResources pool
disconnect (ClusteredConnection _ pool _) = destroyAllResources pool
disconnect (ClusteredConnection _ conn) = Cluster.destroyNodeResources conn

-- | Memory bracket around 'connect' and 'disconnect'.
withConnect :: (Catch.MonadMask m, MonadIO m) => ConnectInfo -> (Connection -> m c) -> m c
Expand All @@ -203,8 +204,8 @@ withCheckedConnect connInfo = bracket (checkedConnect connInfo) disconnect
runRedis :: Connection -> Redis a -> IO a
runRedis (NonClusteredConnection pool) redis =
withResource pool $ \conn -> runRedisInternal conn redis
runRedis (ClusteredConnection _ pool bootstrapConnInfo) redis =
withResource pool $ \conn -> runRedisClusteredInternal conn (refreshShardMap conn bootstrapConnInfo) redis
runRedis (ClusteredConnection bootstrapConnInfo conn) redis =
runRedisClusteredInternal conn (refreshShardMap bootstrapConnInfo conn) redis

newtype ClusterConnectError = ClusterConnectError Reply
deriving (Eq, Show, Typeable)
Expand All @@ -220,39 +221,42 @@ instance Exception ClusterConnectError
-- - MOVE, SELECT
-- - PUBLISH, SUBSCRIBE, PSUBSCRIBE, UNSUBSCRIBE, PUNSUBSCRIBE, RESET
connectCluster :: ConnectInfo -> IO Connection
connectCluster bootstrapConnInfo = do
connectCluster bootstrapConnInfo@ConnInfo{connectMaxConnections,connectMaxIdleTime} = do
conn <- createConnection bootstrapConnInfo
slotsResponse <- runRedisInternal conn clusterSlots
shardMapVar <- case slotsResponse of
shardMap <- case slotsResponse of
Left e -> throwIO $ ClusterConnectError e
Right slots -> do
shardMap <- shardMapFromClusterSlotsResponse slots
newMVar shardMap
Right slots -> shardMapFromClusterSlotsResponse slots
commandInfos <- runRedisInternal conn command
case commandInfos of
Left e -> throwIO $ ClusterConnectError e
Right infos -> do
let
isConnectionReadOnly = connectReadOnly bootstrapConnInfo
clusterConnection = Cluster.connect (connectWithAuth bootstrapConnInfo) infos shardMapVar (clusterConnectTimeoutinUs bootstrapConnInfo) isConnectionReadOnly refreshShardMapWithNodeConn
pool <- createPool (clusterConnect isConnectionReadOnly clusterConnection) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)
return $ ClusteredConnection shardMapVar pool bootstrapConnInfo
where
clusterConnect :: Bool -> IO Cluster.Connection -> IO Cluster.Connection
clusterConnect readOnlyConnection connection = do
clusterConn@(Cluster.Connection nodeMapMvar _ _ _ _) <- connection
nodeMap <- Cluster.hasLocked $ readMVar nodeMapMvar
nodesConns <- sequence $ ( PP.fromCtx . (\(Cluster.NodeConnection ctx _ _) -> ctx ) . snd) <$> (HM.toList nodeMap)
when readOnlyConnection $
mapM_ (\conn -> do
PP.beginReceiving conn
runRedisInternal conn readOnly
) nodesConns
return clusterConn
let withAuth = connectWithAuth bootstrapConnInfo
clusterConnection <- Cluster.createClusterConnectionPools withAuth connectMaxConnections connectMaxIdleTime infos shardMap
return $ ClusteredConnection bootstrapConnInfo clusterConnection

clusterConnectTimeoutinUs :: ConnectInfo -> Maybe Int
clusterConnectTimeoutinUs bootstrapConnInfo =
round . (1000000 *) <$> connectTimeout bootstrapConnInfo
connectWithAuth :: ConnectInfo -> Cluster.Host -> CC.PortID -> IO CC.ConnectionContext
connectWithAuth ConnInfo{connectTLSParams,connectAuth,connectReadOnly,connectTimeout} host port = do
conn <- PP.connect host port $ clusterConnectTimeoutinUs <$> connectTimeout
conn' <- case connectTLSParams of
Nothing -> return conn
Just tlsParams -> PP.enableTLS tlsParams conn
PP.beginReceiving conn'
runRedisInternal conn' $ do
-- AUTH
case connectAuth of
Nothing -> return ()
Just pass -> do
resp <- auth pass
case resp of
Left r -> liftIO $ throwIO $ ConnectAuthError r
_ -> return ()
when connectReadOnly $ do
runRedisInternal conn' readOnly >> return()
return $ PP.toCtx conn'

clusterConnectTimeoutinUs :: Time.NominalDiffTime -> Int
clusterConnectTimeoutinUs = round . (1000000 *)

shardMapFromClusterSlotsResponse :: ClusterSlotsResponse -> IO ShardMap
shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} = ShardMap <$> foldr mkShardMap (pure IntMap.empty) clusterSlotsResponseEntries where
Expand All @@ -271,59 +275,54 @@ shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} = ShardMap <$> foldr m
in
Cluster.Node clusterSlotsNodeID role hostname (toEnum clusterSlotsNodePort)

refreshShardMap :: Cluster.Connection -> ConnectInfo -> IO ShardMap
refreshShardMap (Cluster.Connection nodeConnsMvar _ _ _ _) bootstrapConnInfo = do
nodeConns <- Cluster.hasLocked $ readMVar nodeConnsMvar
updatedShardMap <- refreshShardMapWithNodeConn (HM.elems nodeConns)
updatedConn <- getConnectionsMapFromShardMap updatedShardMap bootstrapConnInfo
Cluster.hasLocked $ modifyMVar_ nodeConnsMvar (const (pure updatedConn))
pure updatedShardMap

getConnectionsMapFromShardMap :: ShardMap -> ConnectInfo -> IO (HM.HashMap NodeID NodeConnection)
getConnectionsMapFromShardMap shardMap bootstrapConnInfo = do
let nodes = nub $ Cluster.nodes shardMap
connectNodeWithAuth = Cluster.connectNode (connectWithAuth bootstrapConnInfo) (clusterConnectTimeoutinUs bootstrapConnInfo)
connRes <- mapM (\node ->
connectNodeWithAuth node `catch` (\(err :: SomeException) -> throwIO (Cluster.RefreshNodesException $ show err))) nodes
return $ foldl (\acc (v, nc) -> HM.insert v nc acc) mempty connRes

connectWithAuth :: ConnectInfo -> Cluster.Host -> CC.PortID -> Maybe Int -> IO CC.ConnectionContext
connectWithAuth bootstrapConnInfo host port timeout = do
conn <- PP.connect host port timeout
conn' <- case connectTLSParams bootstrapConnInfo of
Nothing -> return conn
Just tlsParams -> PP.enableTLS tlsParams conn
PP.beginReceiving conn'

runRedisInternal conn' $ do
-- AUTH
case connectAuth bootstrapConnInfo of
Nothing -> return ()
Just pass -> do
resp <- auth pass
case resp of
Left r -> liftIO $ throwIO $ ConnectAuthError r
_ -> return ()
return $ PP.toCtx conn'
refreshShardMap :: ConnectInfo -> Cluster.Connection -> IO ShardMap
refreshShardMap connectInfo@ConnInfo{connectMaxConnections,connectMaxIdleTime} (Cluster.Connection shardNodeVar _ _) = do
modifyMVar shardNodeVar $ \(_, oldNodeConnMap) -> do
newShardMap <- refreshShardMapWithNodeConn (HM.elems oldNodeConnMap)
newNodeConnMap <- updateNodeConnections newShardMap oldNodeConnMap
return ((newShardMap, newNodeConnMap), newShardMap)
where
withAuth :: Cluster.Host -> CC.PortID -> IO CC.ConnectionContext
withAuth = connectWithAuth connectInfo
updateNodeConnections :: ShardMap -> HM.HashMap Cluster.NodeID Cluster.NodeConnection -> IO (HM.HashMap Cluster.NodeID Cluster.NodeConnection)
updateNodeConnections newShardMap oldNodeConnMap = do
foldM (\acc node@(Cluster.Node nodeid _ _ _) ->
case HM.lookup nodeid oldNodeConnMap of
Just nodeconn -> return $ HM.insert nodeid nodeconn acc
Nothing -> do
(_,nodeConnPool) <- Cluster.createNodePool withAuth connectMaxConnections connectMaxIdleTime node
return $ HM.insert nodeid nodeConnPool acc
) HM.empty (nub $ Cluster.nodes newShardMap)

refreshShardMapWithNodeConn :: [Cluster.NodeConnection] -> IO ShardMap
refreshShardMapWithNodeConn [] = throwIO $ ClusterConnectError (Error "Couldn't refresh shardMap due to connection error")
refreshShardMapWithNodeConn nodeConnsList = do
selectedIdx <- randomRIO (0, (length nodeConnsList) - 1)
let (Cluster.NodeConnection ctx _ _) = nodeConnsList !! selectedIdx
pipelineConn <- PP.fromCtx ctx
envTimeout <- fromMaybe (10 ^ (5 :: Int)) . (>>= readMaybe) <$> lookupEnv "REDIS_CLUSTER_SLOTS_TIMEOUT"
raceResult <- T.timeout envTimeout (try $ refreshShardMapWithConn pipelineConn True)-- racing with delay of default 1 ms
case raceResult of
Nothing -> do
print $ "TimeoutForConnection " <> show ctx
throwIO $ Cluster.TimeoutException "ClusterSlots Timeout"
Just eiShardMapResp ->
case eiShardMapResp of
Right shardMap -> pure shardMap
Left (err :: SomeException) -> do
print $ "ShardMapRefreshError-" <> show err
throwIO $ ClusterConnectError (Error "Couldn't refresh shardMap due to connection error")
let numOfNodes = length nodeConnsList
selectedIdx <- randomRIO (0, length nodeConnsList - 1)
let (Cluster.NodeConnection pool _) = nodeConnsList !! selectedIdx
eresp <- try $ refreshShardMapWithPool pool
case eresp of
Left (_::SomeException) -> do -- retry on other node
let otherSelectedIdx = (selectedIdx + 1) `mod` numOfNodes
(Cluster.NodeConnection otherPool _) = nodeConnsList !! otherSelectedIdx
refreshShardMapWithPool otherPool
Right shardMap -> return shardMap
where
refreshShardMapWithPool pool = withResource pool $
\(ctx,_) -> do
pipelineConn <- PP.fromCtx ctx
envTimeout <- fromMaybe (10 ^ (5 :: Int)) . (>>= readMaybe) <$> lookupEnv "REDIS_CLUSTER_SLOTS_TIMEOUT"
eresp <- T.timeout envTimeout (try $ refreshShardMapWithConn pipelineConn True) -- racing with delay of default 100 ms
case eresp of
Nothing -> do
print $ "TimeoutForConnection " <> show ctx
throwIO $ Cluster.TimeoutException "ClusterSlots Timeout"
Just eiShardMapResp ->
case eiShardMapResp of
Right shardMap -> pure shardMap
Left (err :: SomeException) -> do
print $ "ShardMapRefreshError-" <> show err
throwIO $ ClusterConnectError (Error "Couldn't refresh shardMap due to connection error")

refreshShardMapWithConn :: PP.Connection -> Bool -> IO ShardMap
refreshShardMapWithConn pipelineConn _ = do
Expand Down
7 changes: 5 additions & 2 deletions src/Database/Redis/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Prelude
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
import Control.Concurrent.MVar (newMVar)
import Control.Monad.Reader
import qualified Data.ByteString as B
import Data.IORef
Expand Down Expand Up @@ -75,7 +76,9 @@ runRedisInternal conn (Redis redis) = do
runRedisClusteredInternal :: Cluster.Connection -> IO ShardMap -> Redis a -> IO a
runRedisClusteredInternal connection refreshShardmapAction (Redis redis) = do
ref <- newIORef (SingleLine "nobody will ever see this")
r <- runReaderT redis (ClusteredEnv refreshShardmapAction connection ref)
stateVar <- liftIO $ newMVar $ Cluster.Pending []
pipelineVar <- liftIO $ newMVar $ Cluster.Pipeline stateVar
r <- runReaderT redis (ClusteredEnv refreshShardmapAction connection ref pipelineVar)
readIORef ref >>= (`seq` return ())
return r

Expand Down Expand Up @@ -117,7 +120,7 @@ sendRequest req = do
setLastReply r
return r
ClusteredEnv{..} -> do
r <- liftIO $ Cluster.requestPipelined refreshAction connection req
r <- liftIO $ Cluster.requestPipelined refreshAction connection req pipeline
lift (writeIORef clusteredLastReply r)
return r
returnDecode r'
Expand Down
2 changes: 2 additions & 0 deletions src/Database/Redis/Core/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Control.Monad.Fail (MonadFail)
import Control.Monad.Reader
import Data.IORef
import Database.Redis.Protocol
import Control.Concurrent.MVar (MVar)
import Control.Monad.IO.Unlift (MonadUnliftIO)
import qualified Database.Redis.ProtocolPipelining as PP
import qualified Database.Redis.Cluster as Cluster
Expand All @@ -31,6 +32,7 @@ data RedisEnv
{ refreshAction :: IO Cluster.ShardMap
, connection :: Cluster.Connection
, clusteredLastReply :: IORef Reply
, pipeline :: MVar Cluster.Pipeline
}

envLastReply :: RedisEnv -> IORef Reply
Expand Down
2 changes: 1 addition & 1 deletion src/Database/Redis/PubSub.hs
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ pubSubForever (Connection.NonClusteredConnection pool) ctrl onInitialLoad = with
(Right (Left err)) -> throwIO err
(Left (Left err)) -> throwIO err
_ -> return () -- should never happen, since threads exit only with an error
pubSubForever (Connection.ClusteredConnection _ _ _) _ _ = undefined
pubSubForever (Connection.ClusteredConnection _ _) _ _ = undefined


------------------------------------------------------------------------------
Expand Down
Loading