Skip to content

Commit

Permalink
Generalize the base monad of fold-like operations
Browse files Browse the repository at this point in the history
Previously the fold-like operations were restricted to fold operations
in `IO`, greatly limiting their usefulness. Here we generalize them to
any `MonadMask`, provided by the widely-used `exceptions` library.

Resolves #9.
  • Loading branch information
bgamari committed Sep 30, 2016
1 parent c5e489a commit 5a418fc
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 95 deletions.
1 change: 1 addition & 0 deletions postgresql-simple.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Library
bytestring-builder,
case-insensitive,
containers,
exceptions,
hashable,
postgresql-libpq >= 0.9 && < 0.10,
template-haskell,
Expand Down
139 changes: 72 additions & 67 deletions src/Database/PostgreSQL/Simple.hs
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ module Database.PostgreSQL.Simple
import Data.ByteString.Builder
( Builder, byteString, char8, intDec )
import Control.Applicative ((<$>))
import Control.Exception as E
import Control.Monad (unless)
import Control.Monad (unless, void)
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.List (intersperse)
Expand All @@ -140,6 +139,8 @@ import Database.PostgreSQL.Simple.Transaction
import Database.PostgreSQL.Simple.TypeInfo
import qualified Database.PostgreSQL.LibPQ as PQ
import qualified Data.ByteString.Char8 as B
import Control.Monad.IO.Class
import Control.Monad.Catch as E
import Control.Monad.Trans.Reader
import Control.Monad.Trans.State.Strict

Expand Down Expand Up @@ -429,24 +430,24 @@ queryWith_ parser conn q@(Query que) = do
--
-- * 'SqlError': the postgresql backend returned an error, e.g.
-- a syntax or type error, or an incorrect table or column name.
fold :: ( FromRow row, ToRow params )
fold :: ( MonadIO m, MonadMask m, FromRow row, ToRow params )
=> Connection
-> Query
-> params
-> a
-> (a -> row -> IO a)
-> IO a
-> (a -> row -> m a)
-> m a
fold = foldWithOptions defaultFoldOptions

-- | A version of 'fold' taking a parser as an argument
foldWith :: ( ToRow params )
foldWith :: ( MonadIO m, MonadMask m, ToRow params )
=> RowParser row
-> Connection
-> Query
-> params
-> a
-> (a -> row -> IO a)
-> IO a
-> (a -> row -> m a)
-> m a
foldWith = foldWithOptionsAndParser defaultFoldOptions

-- | Number of rows to fetch at a time. 'Automatic' currently defaults
Expand Down Expand Up @@ -475,77 +476,80 @@ defaultFoldOptions = FoldOptions {
-- accordingly. If the connection is already in a transaction,
-- then the existing transaction is used and thus the 'transactionMode'
-- option is ignored.
foldWithOptions :: ( FromRow row, ToRow params )
foldWithOptions :: ( MonadIO m, MonadMask m, FromRow row, ToRow params )
=> FoldOptions
-> Connection
-> Query
-> params
-> a
-> (a -> row -> IO a)
-> IO a
-> (a -> row -> m a)
-> m a
foldWithOptions opts = foldWithOptionsAndParser opts fromRow

-- | A version of 'foldWithOptions' taking a parser as an argument
foldWithOptionsAndParser :: (ToRow params)
foldWithOptionsAndParser :: ( MonadIO m, MonadMask m, ToRow params )
=> FoldOptions
-> RowParser row
-> Connection
-> Query
-> params
-> a
-> (a -> row -> IO a)
-> IO a
-> (a -> row -> m a)
-> m a
foldWithOptionsAndParser opts parser conn template qs a f = do
q <- formatQuery conn template qs
q <- liftIO $ formatQuery conn template qs
doFold opts parser conn template (Query q) a f

-- | A version of 'fold' that does not perform query substitution.
fold_ :: (FromRow r) =>
Connection
fold_ :: ( MonadIO m, MonadMask m, FromRow r )
=> Connection
-> Query -- ^ Query.
-> a -- ^ Initial state for result consumer.
-> (a -> r -> IO a) -- ^ Result consumer.
-> IO a
-> (a -> r -> m a) -- ^ Result consumer.
-> m a
fold_ = foldWithOptions_ defaultFoldOptions

-- | A version of 'fold_' taking a parser as an argument
foldWith_ :: RowParser r
foldWith_ :: ( MonadIO m, MonadMask m)
=> RowParser r
-> Connection
-> Query
-> a
-> (a -> r -> IO a)
-> IO a
-> (a -> r -> m a)
-> m a
foldWith_ = foldWithOptionsAndParser_ defaultFoldOptions

foldWithOptions_ :: (FromRow r) =>
FoldOptions
foldWithOptions_ :: ( MonadIO m, MonadMask m, FromRow r)
=> FoldOptions
-> Connection
-> Query -- ^ Query.
-> a -- ^ Initial state for result consumer.
-> (a -> r -> IO a) -- ^ Result consumer.
-> IO a
-> (a -> r -> m a) -- ^ Result consumer.
-> m a
foldWithOptions_ opts conn query a f = doFold opts fromRow conn query query a f

-- | A version of 'foldWithOptions_' taking a parser as an argument
foldWithOptionsAndParser_ :: FoldOptions
foldWithOptionsAndParser_ :: ( MonadIO m, MonadMask m )
=> FoldOptions
-> RowParser r
-> Connection
-> Query -- ^ Query.
-> a -- ^ Initial state for result consumer.
-> (a -> r -> IO a) -- ^ Result consumer.
-> IO a
-> (a -> r -> m a) -- ^ Result consumer.
-> m a
foldWithOptionsAndParser_ opts parser conn query a f = doFold opts parser conn query query a f

doFold :: FoldOptions
doFold :: ( MonadIO m, MonadMask m )
=> FoldOptions
-> RowParser row
-> Connection
-> Query
-> Query
-> a
-> (a -> row -> IO a)
-> IO a
-> (a -> row -> m a)
-> m a
doFold FoldOptions{..} parser conn _template q a0 f = do
stat <- withConnection conn PQ.transactionStatus
stat <- liftIO $ withConnection conn PQ.transactionStatus
case stat of
PQ.TransIdle -> withTransactionMode transactionMode conn go
PQ.TransInTrans -> go
Expand All @@ -563,15 +567,15 @@ doFold FoldOptions{..} parser conn _template q a0 f = do
-- Not sure what this means.
where
declare = do
name <- newTempName conn
_ <- execute_ conn $ mconcat
name <- liftIO $ newTempName conn
_ <- liftIO $ execute_ conn $ mconcat
[ "DECLARE ", name, " NO SCROLL CURSOR FOR ", q ]
return name
close name =
(execute_ conn ("CLOSE " <> name) >> return ()) `E.catch` \ex ->
(void $ liftIO $ execute_ conn ("CLOSE " <> name)) `E.catch` \ex ->
-- Don't throw exception if CLOSE failed because the transaction is
-- aborted. Otherwise, it will throw away the original error.
unless (isFailedTransactionError ex) $ throwIO ex
unless (isFailedTransactionError ex) $ throwM ex

go = bracket declare close $ \(Query name) ->
let q = toByteString (byteString "FETCH FORWARD "
Expand All @@ -580,20 +584,20 @@ doFold FoldOptions{..} parser conn _template q a0 f = do
<> byteString name
)
loop a = do
result <- exec conn q
status <- PQ.resultStatus result
result <- liftIO $ exec conn q
status <- liftIO $ PQ.resultStatus result
case status of
PQ.TuplesOk -> do
nrows <- PQ.ntuples result
ncols <- PQ.nfields result
nrows <- liftIO $ PQ.ntuples result
ncols <- liftIO $ PQ.nfields result
if nrows > 0
then do
let inner a row = do
x <- getRowWith parser row ncols conn result
x <- liftIO $ getRowWith parser row ncols conn result
f a x
foldM' inner a 0 (nrows - 1) >>= loop
else return a
_ -> throwResultError "fold" result status
_ -> liftIO $ throwResultError "fold" result status
in loop a0

-- FIXME: choose the Automatic chunkSize more intelligently
Expand All @@ -607,44 +611,45 @@ doFold FoldOptions{..} parser conn _template q a0 f = do
Fixed n -> n

-- | A version of 'fold' that does not transform a state value.
forEach :: (ToRow q, FromRow r) =>
Connection
forEach :: ( MonadIO m, MonadMask m, ToRow q, FromRow r )
=> Connection
-> Query -- ^ Query template.
-> q -- ^ Query parameters.
-> (r -> IO ()) -- ^ Result consumer.
-> IO ()
-> (r -> m ()) -- ^ Result consumer.
-> m ()
forEach = forEachWith fromRow
{-# INLINE forEach #-}

-- | A version of 'forEach' taking a parser as an argument
forEachWith :: ( ToRow q )
forEachWith :: ( MonadIO m, MonadMask m, ToRow q )
=> RowParser r
-> Connection
-> Query
-> q
-> (r -> IO ())
-> IO ()
-> (r -> m ())
-> m ()
forEachWith parser conn template qs = foldWith parser conn template qs () . const
{-# INLINE forEachWith #-}

-- | A version of 'forEach' that does not perform query substitution.
forEach_ :: (FromRow r) =>
Connection
forEach_ :: ( MonadIO m, MonadMask m, FromRow r )
=> Connection
-> Query -- ^ Query template.
-> (r -> IO ()) -- ^ Result consumer.
-> IO ()
-> (r -> m ()) -- ^ Result consumer.
-> m ()
forEach_ = forEachWith_ fromRow
{-# INLINE forEach_ #-}

forEachWith_ :: RowParser r
forEachWith_ :: ( MonadIO m , MonadMask m )
=> RowParser r
-> Connection
-> Query
-> (r -> IO ())
-> IO ()
-> (r -> m ())
-> m ()
forEachWith_ parser conn template = foldWith_ parser conn template () . const
{-# INLINE forEachWith_ #-}

forM' :: (Ord n, Num n) => n -> n -> (n -> IO a) -> IO [a]
forM' :: (Monad m, Ord n, Num n) => n -> n -> (n -> m a) -> m [a]
forM' lo hi m = loop hi []
where
loop !n !as
Expand All @@ -654,7 +659,7 @@ forM' lo hi m = loop hi []
loop (n-1) (a:as)
{-# INLINE forM' #-}

foldM' :: (Ord n, Num n) => (a -> n -> IO a) -> a -> n -> n -> IO a
foldM' :: (Monad m, Ord n, Num n) => (a -> n -> m a) -> a -> n -> n -> m a
foldM' f a lo hi = loop a lo
where
loop a !n
Expand All @@ -669,18 +674,18 @@ finishQueryWith parser conn q result = do
status <- PQ.resultStatus result
case status of
PQ.EmptyQuery ->
throwIO $ QueryError "query: Empty query" q
throwM $ QueryError "query: Empty query" q
PQ.CommandOk ->
throwIO $ QueryError "query resulted in a command response" q
throwM $ QueryError "query resulted in a command response" q
PQ.TuplesOk -> do
nrows <- PQ.ntuples result
ncols <- PQ.nfields result
forM' 0 (nrows-1) $ \row ->
getRowWith parser row ncols conn result
PQ.CopyOut ->
throwIO $ QueryError "query: COPY TO is not supported" q
throwM $ QueryError "query: COPY TO is not supported" q
PQ.CopyIn ->
throwIO $ QueryError "query: COPY FROM is not supported" q
throwM $ QueryError "query: COPY FROM is not supported" q
PQ.BadResponse -> throwResultError "query" result status
PQ.NonfatalError -> throwResultError "query" result status
PQ.FatalError -> throwResultError "query" result status
Expand All @@ -698,16 +703,16 @@ getRowWith parser row ncols conn result = do
v <- PQ.getvalue result row c
return ( tinfo
, fmap ellipsis v )
throw (ConversionFailed
throwM (ConversionFailed
(show (unCol ncols) ++ " values: " ++ show vals)
Nothing
""
(show (unCol col) ++ " slots in target type")
"mismatch between number of columns to \
\convert and number in target type")
Errors [] -> throwIO $ ConversionFailed "" Nothing "" "" "unknown error"
Errors [x] -> throwIO x
Errors xs -> throwIO $ ManyErrors xs
Errors [] -> throwM $ ConversionFailed "" Nothing "" "" "unknown error"
Errors [x] -> throwM x
Errors xs -> throwM $ ManyErrors xs

ellipsis :: ByteString -> ByteString
ellipsis bs
Expand Down
21 changes: 1 addition & 20 deletions src/Database/PostgreSQL/Simple/Compat.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
-- | This is a module of its own, partly because it uses the CPP extension,
-- which doesn't play well with backslash-broken string literals.
module Database.PostgreSQL.Simple.Compat
( mask
, (<>)
( (<>)
, unsafeDupablePerformIO
, toByteString
, scientificBuilder
, toPico
, fromPico
) where

import qualified Control.Exception as E
import Data.Monoid
import Data.ByteString (ByteString)
#if MIN_VERSION_bytestring(0,10,0)
Expand Down Expand Up @@ -43,23 +41,6 @@ import Data.Fixed (Fixed(MkFixed))
import Unsafe.Coerce (unsafeCoerce)
#endif

-- | Like 'E.mask', but backported to base before version 4.3.0.
--
-- Note that the restore callback is monomorphic, unlike in 'E.mask'. This
-- could be fixed by changing the type signature, but it would require us to
-- enable the RankNTypes extension (since 'E.mask' has a rank-3 type). The
-- 'withTransactionMode' function calls the restore callback only once, so we
-- don't need that polymorphism.
mask :: ((IO a -> IO a) -> IO b) -> IO b
#if MIN_VERSION_base(4,3,0)
mask io = E.mask $ \restore -> io restore
#else
mask io = do
b <- E.blocked
E.block $ io $ \m -> if b then m else E.unblock m
#endif
{-# INLINE mask #-}

#if !MIN_VERSION_base(4,5,0)
infixr 6 <>

Expand Down
Loading

0 comments on commit 5a418fc

Please sign in to comment.