diff --git a/docker-compose.yml b/docker-compose.yml index 998912a54..aefe707e0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,7 +18,7 @@ services: mysql: - image: mysql:8.0-debian + image: mysql:8.0 environment: MYSQL_ROOT_PASSWORD: password MYSQL_DATABASE: world diff --git a/modules/core/src/main/scala/doobie/syntax/applicativeerror.scala b/modules/core/src/main/scala/doobie/syntax/applicativeerror.scala index 929a88386..d4a2bd4e6 100644 --- a/modules/core/src/main/scala/doobie/syntax/applicativeerror.scala +++ b/modules/core/src/main/scala/doobie/syntax/applicativeerror.scala @@ -11,6 +11,7 @@ import java.sql.SQLException class ApplicativeErrorOps[M[_], A](self: M[A])(implicit ev: ApplicativeError[M, Throwable]) { def attemptSql: M[Either[SQLException, A]] = C.attemptSql(self) + def attemptSomeSql[B](f: PartialFunction[SQLException, B]): M[Either[B, A]] = C.attemptSomeSql(self)(f) def attemptSqlState: M[Either[SqlState, A]] = C.attemptSqlState(self) def attemptSomeSqlState[B](f: PartialFunction[SqlState, B]): M[Either[B, A]] = C.attemptSomeSqlState(self)(f) def exceptSql(handler: SQLException => M[A]): M[A] = C.exceptSql(self)(handler) diff --git a/modules/core/src/main/scala/doobie/util/catchsql.scala b/modules/core/src/main/scala/doobie/util/catchsql.scala index 439fc763d..33db9889e 100644 --- a/modules/core/src/main/scala/doobie/util/catchsql.scala +++ b/modules/core/src/main/scala/doobie/util/catchsql.scala @@ -22,6 +22,12 @@ object catchsql { case e: SQLException => e.asLeft } + /** Like `attempt` but catches only the defined `SQLException`. */ + def attemptSomeSql[M[_], A, B](ma: M[A])(f: PartialFunction[SQLException, B])(implicit AE: ApplicativeError[M, Throwable]): M[Either[B, A]] = + ma.map(_.asRight[B]).recoverWith { case e: SQLException => + f.lift(e).fold(AE.raiseError[Either[B, A]](e))(b => AE.pure(b.asLeft)) + } + /** Like `attemptSql` but yields only the exception's `SqlState`. */ def attemptSqlState[M[_], A](ma: M[A])( implicit ev: ApplicativeError[M, Throwable] diff --git a/modules/core/src/test/scala/doobie/util/CatchSqlSuite.scala b/modules/core/src/test/scala/doobie/util/CatchSqlSuite.scala index 55c8f6a58..508dcd9fc 100644 --- a/modules/core/src/test/scala/doobie/util/CatchSqlSuite.scala +++ b/modules/core/src/test/scala/doobie/util/CatchSqlSuite.scala @@ -4,8 +4,11 @@ package doobie.util -import cats.effect.{IO} -import doobie.*, doobie.implicits.* +import cats.effect.IO +import doobie.* +import doobie.implicits.* +import org.postgresql.util.{PSQLException, PSQLState} + import java.sql.SQLException class CatchSqlSuite extends munit.FunSuite { @@ -31,16 +34,61 @@ class CatchSqlSuite extends munit.FunSuite { } } - test("attemptSqlState shuold do nothing on success") { + test("attemptSomeSql should do nothing on success") { + assertEquals( + IO.delay(3).attemptSomeSql { + case _: SQLException => 42 + }.unsafeRunSync(), + Right(3)) + } + + test("attemptSomeSql should catch SQLException with matching subtype (1)") { + val e = new SQLException("", SQLSTATE_FOO.value) + assertEquals( + IO.raiseError(e).attemptSomeSql { + case _: SQLException => 42 + }.unsafeRunSync(), + Left(42)) + } + + test("attemptSomeSql should catch SQLException with matching subtype (2)") { + val PSQLSTATE = PSQLState.CHECK_VIOLATION + val e = new PSQLException("", PSQLSTATE) + assertEquals( + IO.raiseError(e).attemptSomeSql { + case exception: PSQLException if exception.getSQLState == PSQLSTATE.getState => 66 + }.unsafeRunSync(), + Left(66)) + } + + test("attemptSomeSql should ignore SQLException with non-matching subtype") { + final case class AnotherSQLException(message: String) extends SQLException(message) + val e = AnotherSQLException("") + intercept[AnotherSQLException] { + IO.raiseError(e).attemptSomeSql { + case exception: PSQLException if exception.getSQLState == "Baz" => 66 + }.unsafeRunSync() + } + } + + test("attemptSomeSql should ignore non-SQLException") { + val e = new IllegalArgumentException + intercept[IllegalArgumentException] { + IO.raiseError(e).attemptSomeSql { + case _: SQLException => 42 + }.unsafeRunSync() + } + } + test("attemptSqlState should do nothing on success") { assertEquals(IO.delay(3).attemptSqlState.unsafeRunSync(), Right(3)) } - test("attemptSqlState shuold catch SQLException") { + test("attemptSqlState should catch SQLException") { val e = new SQLException("", SQLSTATE_FOO.value) assertEquals(IO.raiseError(e).attemptSqlState.unsafeRunSync(), Left(SQLSTATE_FOO)) } - test("attemptSqlState shuold ignore non-SQLException") { + test("attemptSqlState should ignore non-SQLException") { val e = new IllegalArgumentException intercept[IllegalArgumentException] { IO.raiseError(e).attemptSqlState.unsafeRunSync()