Skip to content

Commit

Permalink
fix query cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
TalkingFoxMid committed Aug 5, 2024
1 parent 9c7c3ad commit 2d86feb
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
run: sbt +update

- name: Start up Postgres
run: docker-compose up -d
run: docker compose up -d

- name: Check Headers
run: sbt '++ ${{ matrix.scala }}' headerCheckAll
Expand Down
2 changes: 1 addition & 1 deletion modules/core/src/main/scala/doobie/util/transactor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ object transactor {
val strategy = strategy0
}

def withLogHandler(logHandler: LogHandler[M])(implicit ev: WeakAsync[M]): Transactor.Aux[M, A] = copy(
def withLogHandler(logHandler: LogHandler[M])(implicit ev: Async[M]): Transactor.Aux[M, A] = copy(
interpret0 =
KleisliInterpreter[M](logHandler).ConnectionInterpreter
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) 2013-2020 Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package doobie.util

import cats.effect.{IO, Resource}
import cats.implicits.catsSyntaxApplicativeId
import doobie.Transactor
import doobie.*
import doobie.implicits.*
import doobie.syntax.*
import cats.syntax.all.*

import scala.concurrent.duration.DurationInt

class QueryCancellationSuite extends munit.FunSuite {
import cats.effect.unsafe.implicits.global

val xa = Transactor.fromDriverManager[IO](
driver = "org.h2.Driver",
url = "jdbc:h2:mem:queryspec;DB_CLOSE_DELAY=-1",
user = "sa",
password = "",
logHandler = None
)

test("Query cancel") {
val scenario = WeakAsync.liftIO[ConnectionIO].use { elevator =>
for {
_ <- sql"CREATE TABLE IF NOT EXISTS example_table ( id INT)".update.run.transact(xa)
_ <- sql"TRUNCATE TABLE example_table".update.run.transact(xa)
_ <- sql"INSERT INTO example_table (id) VALUES (1)".update.run.transact(xa)
_ <- {
sql"select * from example_table for update".query[Int].unique >> elevator.liftIO(IO.never)
}.transact(xa).start

insertWithLockFiber <- {
for {
_ <- IO.sleep(100.milli)
insertFiber <- sql"UPDATE example_table SET id = 2".update.run.transact(xa).start
_ <- IO.sleep(100.milli)
_ <- insertFiber.cancel
} yield ()
}.start

_ <- IO.race(insertWithLockFiber.join, IO.sleep(5.seconds) >> IO(fail("Cancellation is blocked")))
result <- sql"SELECT * FROM example_table".query[Int].to[List].transact(xa)
} yield assertEquals(result, List(1))
}
scenario.unsafeRunSync()
}
}
26 changes: 17 additions & 9 deletions modules/free/src/main/scala/doobie/free/kleisliinterpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ package doobie.free
// Library imports
import cats.~>
import cats.data.Kleisli
import cats.effect.kernel.{ Poll, Sync }
import cats.effect.kernel.{Async, Outcome, Poll, Sync}
import cats.free.Free
import doobie.WeakAsync
import doobie.util.cancellation.CancellationForker
import doobie.util.log.{LogEvent, LogHandler}

import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration

Expand Down Expand Up @@ -73,12 +74,12 @@ import doobie.free.callablestatement.{ CallableStatementIO, CallableStatementOp
import doobie.free.resultset.{ ResultSetIO, ResultSetOp }

object KleisliInterpreter {
def apply[M[_]: WeakAsync](logHandler: LogHandler[M]): KleisliInterpreter[M] =
def apply[M[_]: Async](logHandler: LogHandler[M]): KleisliInterpreter[M] =
new KleisliInterpreter[M](logHandler)
}

// Family of interpreters into Kleisli arrows for some monad M.
class KleisliInterpreter[M[_]](logHandler: LogHandler[M])(implicit val asyncM: WeakAsync[M]) { outer =>
class KleisliInterpreter[M[_]](logHandler: LogHandler[M])(implicit val asyncM: Async[M]) { outer =>

// The 14 interpreters, with definitions below. These can be overridden to customize behavior.
lazy val NClobInterpreter: NClobOp ~> Kleisli[M, NClob, *] = new NClobInterpreter { }
Expand All @@ -100,11 +101,18 @@ class KleisliInterpreter[M[_]](logHandler: LogHandler[M])(implicit val asyncM: W
def primitive[J, A](f: J => A): Kleisli[M, J, A] = Kleisli { a =>
// primitive JDBC methods throw exceptions and so do we when reading values
// so catch any non-fatal exceptions and lift them into the effect
try {
asyncM.blocking(f(a))
} catch {
case scala.util.control.NonFatal(e) => asyncM.raiseError(e)
}
import cats.syntax.all._
for {
jdbcBlockedFiber <- asyncM.start(asyncM.blocking(f(a)).attempt)
a <- jdbcBlockedFiber.join.flatMap {
case Outcome.Succeeded(fa) => fa.flatMap {
case Left(e) => e.raiseError[M, A]
case Right(value) => value.pure[M]
}
case Outcome.Errored(e) => e.raiseError[M, A]
case Outcome.Canceled() => asyncM.never[A] // Cannot be cancelled out of scope
}
} yield a
}
def raw[J, A](f: J => A): Kleisli[M, J, A] = primitive(f)
def raiseError[J, A](e: Throwable): Kleisli[M, J, A] = Kleisli(_ => asyncM.raiseError(e))
Expand Down

0 comments on commit 2d86feb

Please sign in to comment.