diff --git a/benchmarks/src/main/scala/zio/query/ZQueryBenchmark.scala b/benchmarks/src/main/scala/zio/query/ZQueryBenchmark.scala index 19cead0..15f8a88 100644 --- a/benchmarks/src/main/scala/zio/query/ZQueryBenchmark.scala +++ b/benchmarks/src/main/scala/zio/query/ZQueryBenchmark.scala @@ -24,8 +24,16 @@ class ZQueryBenchmark { def zQueryRunSucceedNowBenchmark() = unsafeRunZIO(ZIO.collectAllDiscard(qs1)) + @Benchmark + def zQuerySingleRunSucceedNowBenchmark() = + unsafeRunZIO(qs1.head) + @Benchmark @OperationsPerInvocation(1000) def zQueryRunSucceedBenchmark() = unsafeRunZIO(ZIO.collectAllDiscard(qs2)) + + @Benchmark + def zQuerySingleRunSucceedBenchmark() = + unsafeRunZIO(qs2.head) } diff --git a/zio-query/shared/src/main/scala/zio/query/ZQuery.scala b/zio-query/shared/src/main/scala/zio/query/ZQuery.scala index 4e38e82..a9fa572 100644 --- a/zio-query/shared/src/main/scala/zio/query/ZQuery.scala +++ b/zio-query/shared/src/main/scala/zio/query/ZQuery.scala @@ -17,6 +17,7 @@ package zio.query import zio._ +import zio.query.ZQuery.disabledCache import zio.query.internal._ import zio.stacktracer.TracingImplicits.disableAutoTrace @@ -550,25 +551,46 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result def runCache(cache: => Cache)(implicit trace: Trace): ZIO[R, E, A] = { import ZQuery.{currentCache, currentScope} - def setRef[V](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], newValue: V): V = { - val oldValue = state.getFiberRefOrNull(fiberRef) - state.setFiberRef(fiberRef, newValue) - oldValue + def resetRef[V <: AnyRef]( + fid: FiberId.Runtime, + oldRefs: FiberRefs, + newRefs: FiberRefs + )( + fiberRef: FiberRef[V] + ): FiberRefs = { + val oldValue = oldRefs.getOrNull(fiberRef) + if (oldValue ne null) newRefs.updatedAs(fid)(fiberRef, oldValue) else newRefs.delete(fiberRef) } - def resetRef[V <: AnyRef](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], oldValue: V): Unit = - if (oldValue ne null) state.setFiberRef(fiberRef, oldValue) else state.deleteFiberRef(fiberRef) - asExitOrElse(null) match { case null => ZIO.uninterruptibleMask { restore => ZIO.withFiberRuntime[R, E, A] { (state, _) => - val scope = QueryScope.make() - val oldCache = setRef(state, currentCache, Some(cache)) - val oldScope = setRef(state, currentScope, scope) + // NOTE: Running a ZQuery requires up to 3 FiberRefs, which can be expensive to use `locally` with for simple queries. + // Therefore, we handle them all together to avoid the added penalty of running `locally` 3 times + val fid = state.id + val scope = QueryScope.make() + val oldRefs = state.getFiberRefs(false) + val newRefs = { + val refs = oldRefs.updatedAs(fid)(currentCache, Some(cache)).updatedAs(fid)(currentScope, scope) + if (refs.getOrNull(disabledCache) ne null) + refs.delete(disabledCache) + else refs + } + state.setFiberRefs(newRefs) restore(runToZIO).exitWith { exit => - resetRef(state, currentCache, oldCache) - resetRef(state, currentScope, oldScope) + val curRefs = state.getFiberRefs(false) + if (curRefs eq newRefs) { + // Cheap and common: FiberRefs were not modified during the execution so we just replace them with the old ones + state.setFiberRefs(oldRefs) + } else { + // FiberRefs were mdified so we need to manually revert each one + var revertedRefs = oldRefs + revertedRefs = resetRef(fid, oldRefs, revertedRefs)(currentCache) + revertedRefs = resetRef(fid, oldRefs, revertedRefs)(currentScope) + revertedRefs = resetRef(fid, oldRefs, revertedRefs)(disabledCache) + state.setFiberRefs(revertedRefs) + } scope.closeAndExitWith(exit) } } diff --git a/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala b/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala index 44824c3..b9ae293 100644 --- a/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala +++ b/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala @@ -383,6 +383,16 @@ object ZQuerySpec extends ZIOBaseSpec { q.run.map { case (c1, c2) => assertTrue(c1.isDefined, c1 == c2) } }, + test("disabling caching is reentrant safe") { + val q = + for { + c1 <- ZQuery.fromZIO(ZQuery.currentCache.get) + c2 <- ZQuery.fromZIO(ZQuery.fromZIO(ZQuery.currentCache.get).cached.run).uncached + c3 <- ZQuery.fromZIO(ZQuery.currentCache.get) + } yield (c1, c2, c3) + + q.run.map { case (c1, c2, c3) => assertTrue(c1.isDefined, c2.isDefined, c1 == c3, c1 != c2) } + }, test("scope is reentrant safe") { val q = for {