diff --git a/db-async-common/src/main/scala/com/github/mauricio/async/db/Configuration.scala b/db-async-common/src/main/scala/com/github/mauricio/async/db/Configuration.scala index b032ac02..7e121602 100644 --- a/db-async-common/src/main/scala/com/github/mauricio/async/db/Configuration.scala +++ b/db-async-common/src/main/scala/com/github/mauricio/async/db/Configuration.scala @@ -19,6 +19,8 @@ package com.github.mauricio.async.db import java.nio.charset.Charset import io.netty.buffer.{ByteBufAllocator, PooledByteBufAllocator} +import io.netty.channel.Channel +import io.netty.channel.socket.nio.NioSocketChannel import io.netty.util.CharsetUtil import scala.concurrent.duration._ @@ -45,8 +47,11 @@ object Configuration { * to any value you would like but again, make sure you know what you are doing if you do * change it. * @param allocator the netty buffer allocator to be used + * @param channelClass the netty channel class to use. Should match the type of the event loop group set + * for connections. Defaults to [[NioSocketChannel]] * @param connectTimeout the timeout for connecting to servers * @param testTimeout the timeout for connection tests performed by pools + * @param statementTimeout the optional per-session statement timeout to set in the database * @param queryTimeout the optional query timeout * */ @@ -60,6 +65,8 @@ case class Configuration(username: String, charset: Charset = Configuration.DefaultCharset, maximumMessageSize: Int = 16777216, allocator: ByteBufAllocator = PooledByteBufAllocator.DEFAULT, + channelClass: Class[_ <: Channel] = classOf[NioSocketChannel], connectTimeout: Duration = 5.seconds, testTimeout: Duration = 5.seconds, + statementTimeout: Option[Duration] = None, queryTimeout: Option[Duration] = None) diff --git a/mysql-async/src/main/scala/com/github/mauricio/async/db/mysql/codec/MySQLConnectionHandler.scala b/mysql-async/src/main/scala/com/github/mauricio/async/db/mysql/codec/MySQLConnectionHandler.scala index 792aff77..269cd485 100644 --- a/mysql-async/src/main/scala/com/github/mauricio/async/db/mysql/codec/MySQLConnectionHandler.scala +++ b/mysql-async/src/main/scala/com/github/mauricio/async/db/mysql/codec/MySQLConnectionHandler.scala @@ -68,7 +68,7 @@ class MySQLConnectionHandler( private var currentContext: ChannelHandlerContext = null def connect: Future[MySQLConnectionHandler] = { - this.bootstrap.channel(classOf[NioSocketChannel]) + this.bootstrap.channel(configuration.channelClass) this.bootstrap.handler(new ChannelInitializer[io.netty.channel.Channel]() { override def initChannel(channel: io.netty.channel.Channel): Unit = { diff --git a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/PostgreSQLConnectionHandler.scala b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/PostgreSQLConnectionHandler.scala index 733cc5d1..5b79d05e 100644 --- a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/PostgreSQLConnectionHandler.scala +++ b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/PostgreSQLConnectionHandler.scala @@ -81,7 +81,7 @@ class PostgreSQLConnectionHandler def connect: Future[PostgreSQLConnectionHandler] = { this.bootstrap.group(this.group) - this.bootstrap.channel(classOf[NioSocketChannel]) + this.bootstrap.channel(configuration.channelClass) this.bootstrap.handler(new ChannelInitializer[channel.Channel]() { override def initChannel(ch: channel.Channel): Unit = { diff --git a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/pool/PostgreSQLConnectionFactory.scala b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/pool/PostgreSQLConnectionFactory.scala index ae3c5255..619d93c1 100644 --- a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/pool/PostgreSQLConnectionFactory.scala +++ b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/pool/PostgreSQLConnectionFactory.scala @@ -51,8 +51,16 @@ class PostgreSQLConnectionFactory( def create: PostgreSQLConnection = { val connection = new PostgreSQLConnection(configuration, group = group, executionContext = executionContext) - Await.result(connection.connect, configuration.connectTimeout) - + val future = configuration.statementTimeout match { + case Some(timeout) => { + connection.connect.flatMap(conn => + conn.sendQuery(s"SET statement_timeout TO ${timeout.toMillis};"))(executionContext) + } + case None => { + connection.connect + } + } + Await.result(future, configuration.connectTimeout) connection }