diff --git a/src/main/scala/in/rcard/raise4s/Raise.scala b/src/main/scala/in/rcard/raise4s/Raise.scala index 26c0b49..9f0b445 100644 --- a/src/main/scala/in/rcard/raise4s/Raise.scala +++ b/src/main/scala/in/rcard/raise4s/Raise.scala @@ -34,5 +34,13 @@ def ensureNotNull[B, Error](value: B, raise: () => Error)(using r: Raise[Error]) def recover[Error, A](block: Raise[Error] ?=> () => A, recover: Error => A): A = fold(block, ex => throw ex, recover, identity) -def recover[Error, A](block: Raise[Error] ?=> () => A, recover: Error => A, catchBlock: Throwable => A): A = +def recover[Error, A]( + block: Raise[Error] ?=> () => A, + recover: Error => A, + catchBlock: Throwable => A +): A = fold(block, catchBlock, recover, identity) + +def $catch[A](block: () => A, catchBlock: Throwable => A): A = + try block() + catch case ex: Throwable => catchBlock(ex) diff --git a/src/test/scala/in/rcard/raise4s/RaiseSpec.scala b/src/test/scala/in/rcard/raise4s/RaiseSpec.scala index 389abdd..6b05f16 100644 --- a/src/test/scala/in/rcard/raise4s/RaiseSpec.scala +++ b/src/test/scala/in/rcard/raise4s/RaiseSpec.scala @@ -97,4 +97,22 @@ class RaiseSpec extends AnyFlatSpec with Matchers { actual should be(44) } + + "$catch" should "return the value if no exception is thrown" in { + val actual = $catch( + () => 42, + ex => 43 + ) + + actual should be(42) + } + + it should "return the recovery value if an exception is thrown" in { + val actual = $catch( + () => throw new RuntimeException("error"), + ex => 43 + ) + + actual should be(43) + } }