diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0be841505..dd5739dd3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,7 +52,7 @@ jobs: run: sbt +update - name: Start up Postgres - run: docker-compose up -d + run: docker compose up -d - name: Check Headers run: sbt '++ ${{ matrix.scala }}' headerCheckAll diff --git a/modules/core/src/main/scala/doobie/util/transactor.scala b/modules/core/src/main/scala/doobie/util/transactor.scala index 1eada1ac0..9c19a01e2 100644 --- a/modules/core/src/main/scala/doobie/util/transactor.scala +++ b/modules/core/src/main/scala/doobie/util/transactor.scala @@ -252,7 +252,7 @@ object transactor { val strategy = strategy0 } - def withLogHandler(logHandler: LogHandler[M])(implicit ev: WeakAsync[M]): Transactor.Aux[M, A] = copy( + def withLogHandler(logHandler: LogHandler[M])(implicit ev: Async[M]): Transactor.Aux[M, A] = copy( interpret0 = KleisliInterpreter[M](logHandler).ConnectionInterpreter ) diff --git a/modules/core/src/test/scala/doobie/util/QueryCancellationSuite.scala b/modules/core/src/test/scala/doobie/util/QueryCancellationSuite.scala new file mode 100644 index 000000000..e6d34f2ec --- /dev/null +++ b/modules/core/src/test/scala/doobie/util/QueryCancellationSuite.scala @@ -0,0 +1,53 @@ +// Copyright (c) 2013-2020 Rob Norris and Contributors +// This software is licensed under the MIT License (MIT). +// For more information see LICENSE or https://opensource.org/licenses/MIT + +package doobie.util + +import cats.effect.{IO, Resource} +import cats.implicits.catsSyntaxApplicativeId +import doobie.Transactor +import doobie.* +import doobie.implicits.* +import doobie.syntax.* +import cats.syntax.all.* + +import scala.concurrent.duration.DurationInt + +class QueryCancellationSuite extends munit.FunSuite { + import cats.effect.unsafe.implicits.global + + val xa = Transactor.fromDriverManager[IO]( + driver = "org.h2.Driver", + url = "jdbc:h2:mem:queryspec;DB_CLOSE_DELAY=-1", + user = "sa", + password = "", + logHandler = None + ) + + test("Query cancel") { + val scenario = WeakAsync.liftIO[ConnectionIO].use { elevator => + for { + _ <- sql"CREATE TABLE IF NOT EXISTS example_table ( id INT)".update.run.transact(xa) + _ <- sql"TRUNCATE TABLE example_table".update.run.transact(xa) + _ <- sql"INSERT INTO example_table (id) VALUES (1)".update.run.transact(xa) + _ <- { + sql"select * from example_table for update".query[Int].unique >> elevator.liftIO(IO.never) + }.transact(xa).start + + insertWithLockFiber <- { + for { + _ <- IO.sleep(100.milli) + insertFiber <- sql"UPDATE example_table SET id = 2".update.run.transact(xa).start + _ <- IO.sleep(100.milli) + _ <- insertFiber.cancel + } yield () + }.start + + _ <- IO.race(insertWithLockFiber.join, IO.sleep(5.seconds) >> IO(fail("Cancellation is blocked"))) + result <- sql"SELECT * FROM example_table".query[Int].to[List].transact(xa) + } yield assertEquals(result, List(1)) + } + scenario.unsafeRunSync() + } +} diff --git a/modules/free/src/main/scala/doobie/free/kleisliinterpreter.scala b/modules/free/src/main/scala/doobie/free/kleisliinterpreter.scala index ddf042a5a..7bec6e4f8 100644 --- a/modules/free/src/main/scala/doobie/free/kleisliinterpreter.scala +++ b/modules/free/src/main/scala/doobie/free/kleisliinterpreter.scala @@ -9,10 +9,11 @@ package doobie.free // Library imports import cats.~> import cats.data.Kleisli -import cats.effect.kernel.{ Poll, Sync } +import cats.effect.kernel.{Async, Outcome, Poll, Sync} import cats.free.Free -import doobie.WeakAsync +import doobie.util.cancellation.CancellationForker import doobie.util.log.{LogEvent, LogHandler} + import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration @@ -73,12 +74,12 @@ import doobie.free.callablestatement.{ CallableStatementIO, CallableStatementOp import doobie.free.resultset.{ ResultSetIO, ResultSetOp } object KleisliInterpreter { - def apply[M[_]: WeakAsync](logHandler: LogHandler[M]): KleisliInterpreter[M] = + def apply[M[_]: Async](logHandler: LogHandler[M]): KleisliInterpreter[M] = new KleisliInterpreter[M](logHandler) } // Family of interpreters into Kleisli arrows for some monad M. -class KleisliInterpreter[M[_]](logHandler: LogHandler[M])(implicit val asyncM: WeakAsync[M]) { outer => +class KleisliInterpreter[M[_]](logHandler: LogHandler[M])(implicit val asyncM: Async[M]) { outer => // The 14 interpreters, with definitions below. These can be overridden to customize behavior. lazy val NClobInterpreter: NClobOp ~> Kleisli[M, NClob, *] = new NClobInterpreter { } @@ -100,11 +101,18 @@ class KleisliInterpreter[M[_]](logHandler: LogHandler[M])(implicit val asyncM: W def primitive[J, A](f: J => A): Kleisli[M, J, A] = Kleisli { a => // primitive JDBC methods throw exceptions and so do we when reading values // so catch any non-fatal exceptions and lift them into the effect - try { - asyncM.blocking(f(a)) - } catch { - case scala.util.control.NonFatal(e) => asyncM.raiseError(e) - } + import cats.syntax.all._ + for { + jdbcBlockedFiber <- asyncM.start(asyncM.blocking(f(a)).attempt) + a <- jdbcBlockedFiber.join.flatMap { + case Outcome.Succeeded(fa) => fa.flatMap { + case Left(e) => e.raiseError[M, A] + case Right(value) => value.pure[M] + } + case Outcome.Errored(e) => e.raiseError[M, A] + case Outcome.Canceled() => asyncM.never[A] // Cannot be cancelled out of scope + } + } yield a } def raw[J, A](f: J => A): Kleisli[M, J, A] = primitive(f) def raiseError[J, A](e: Throwable): Kleisli[M, J, A] = Kleisli(_ => asyncM.raiseError(e))