Skip to content

Commit

Permalink
Add support for async and generator functions in rx.pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Mar 21, 2024
1 parent 8456e8b commit c5f2902
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 4 deletions.
48 changes: 44 additions & 4 deletions param/reactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
"""
from __future__ import annotations

import asyncio
import inspect
import math
import operator
Expand All @@ -97,7 +98,7 @@
register_reference_transform, resolve_ref, resolve_value, transform_reference
)
from .parameters import Boolean, Event
from ._utils import iscoroutinefunction, full_groupby
from ._utils import _to_async_gen, iscoroutinefunction, full_groupby


class Wrapper(Parameterized):
Expand All @@ -123,8 +124,9 @@ class Trigger(Parameterized):

value = Event()

def __init__(self, parameters, **params):
def __init__(self, parameters=None, internal=False, **params):
super().__init__(**params)
self.internal = internal
self.parameters = parameters

class Resolver(Parameterized):
Expand Down Expand Up @@ -719,19 +721,27 @@ def __init__(
self._depth = depth
self._dirty = _current is None
self._dirty_obj = False
self._current_task = None
self._error_state = None
self._current_ = _current
if isinstance(obj, rx) and not prev:
self._prev = obj
else:
self._prev = prev

# Define special trigger parameter if operation has to be lazily evaluated
if operation and (iscoroutinefunction(operation['fn']) or inspect.isgeneratorfunction(operation['fn'])):
self._trigger = Trigger(internal=True)
self._current_ = Undefined
else:
self._trigger = None
self._root = self._compute_root()
self._fn_params = self._compute_fn_params()
self._internal_params = self._compute_params()
# Filter params that external objects depend on, ensuring
# that Trigger parameters do not cause double execution
self._params = [
p for p in self._internal_params if not isinstance(p.owner, Trigger)
p for p in self._internal_params if (not isinstance(p.owner, Trigger) or p.owner.internal)
or any (p not in self._internal_params for p in p.owner.parameters)
]
self._setup_invalidations(depth)
Expand Down Expand Up @@ -794,7 +804,9 @@ def _compute_fn_params(self) -> list[Parameter]:
return args + kwargs

def _compute_params(self) -> list[Parameter]:
ps = self._fn_params
ps = list(self._fn_params)
if self._trigger:
ps.append(self._trigger.param.value)

# Collect parameters on previous objects in chain
prev = self._prev
Expand All @@ -808,6 +820,9 @@ def _compute_params(self) -> list[Parameter]:
return ps

# Accumulate dependencies in args and/or kwargs
for ref in resolve_ref(self._operation['fn']):
if ref not in ps:
ps.append(ref)
for arg in list(self._operation['args'])+list(self._operation['kwargs'].values()):
for ref in resolve_ref(arg):
if ref not in ps:
Expand Down Expand Up @@ -843,13 +858,35 @@ def _setup_invalidations(self, depth: int = 0):
params[0].owner.param._watch(self._invalidate_current, [p.name for p in params], precedence=-1)

def _invalidate_current(self, *events):
if all(event.obj is self._trigger for event in events):
return
self._dirty = True
self._error_state = None

def _invalidate_obj(self, *events):
self._root._dirty_obj = True
self._error_state = None

async def _resolve_async(self, obj):
self._current_task = task = asyncio.current_task()
if inspect.isasyncgen(obj):
async for val in obj:
if self._current_task is not task:
break
self._current_ = val
self._trigger.param.trigger('value')
else:
value = await obj
if self._current_task is task:
self._current_ = value
self._trigger.param.trigger('value')

def _lazy_resolve(self, obj):
from .parameterized import async_executor
if inspect.isgenerator(obj):
obj = _to_async_gen(obj)
async_executor(partial(self._resolve_async, obj))

def _resolve(self):
if self._error_state:
raise self._error_state
Expand All @@ -862,6 +899,9 @@ def _resolve(self):
operation = self._operation
if operation:
obj = self._eval_operation(obj, operation)
if inspect.isasyncgen(obj) or inspect.iscoroutine(obj) or inspect.isgenerator(obj):
self._lazy_resolve(obj)
obj = Skip
if obj is Skip:
raise Skip
except Skip:
Expand Down
95 changes: 95 additions & 0 deletions tests/testreactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,18 @@ async def async_func():
await asyncio.sleep(0.04)
assert async_rx.rx.value == 4

async def test_reactive_pipe_async_func():
async def async_func(value):
print(value)
await asyncio.sleep(0.02)
return value+2

async_rx = rx(0).rx.pipe(async_func)
async_rx.rx.watch()
assert async_rx.rx.value is param.Undefined
await asyncio.sleep(0.04)
assert async_rx.rx.value == 2

async def test_reactive_gen():
def gen():
yield 1
Expand All @@ -551,6 +563,29 @@ def gen():
break
assert rxgen.rx.value == 2

async def test_reactive_gen_pipe():
def gen(val):
yield val+1
time.sleep(0.05)
yield val+2

rxv = rx(0)
rxgen = rxv.rx.pipe(gen)
assert rxgen.rx.value is param.Undefined
await asyncio.sleep(0.04)
assert rxgen.rx.value == 1
for _ in range(3):
await asyncio.sleep(0.05)
if rxgen.rx.value == 2:
break
assert rxgen.rx.value == 2
rxv.rx.value = 2
for _ in range(3):
await asyncio.sleep(0.05)
if rxgen.rx.value == 3:
break
assert rxgen.rx.value == 3

async def test_reactive_gen_with_dep():
def gen(i):
yield i+1
Expand All @@ -571,6 +606,31 @@ def gen(i):
break
assert rxgen.rx.value == 5

async def test_reactive_gen_pipe_with_dep():
def gen(value, i):
yield value+i+1
time.sleep(0.05)
yield value+i+2

irx = rx(0)
rxv = rx(0)
rxgen = rxv.rx.pipe(bind(gen, irx))
rxgen.rx.watch()
assert rxgen.rx.value is param.Undefined
await asyncio.sleep(0.04)
assert rxgen.rx.value == 1
irx.rx.value = 3
await asyncio.sleep(0.04)
assert rxgen.rx.value == 4
for _ in range(3):
await asyncio.sleep(0.05)
if rxgen.rx.value == 5:
break
assert rxgen.rx.value == 5
rxv.rx.value = 5
await asyncio.sleep(0.03)
assert rxgen.rx.value == 9

async def test_reactive_async_gen():
async def gen():
yield 1
Expand All @@ -584,6 +644,19 @@ async def gen():
await asyncio.sleep(0.1)
assert rxgen.rx.value == 2

async def test_reactive_async_gen_pipe():
async def gen(value):
yield value + 1
await asyncio.sleep(0.1)
yield value + 2

rxgen = rx(0).rx.pipe(gen)
assert rxgen.rx.value is param.Undefined
await asyncio.sleep(0.05)
assert rxgen.rx.value == 1
await asyncio.sleep(0.1)
assert rxgen.rx.value == 2

async def test_reactive_async_gen_with_dep():
async def gen(i):
yield i+1
Expand All @@ -601,6 +674,28 @@ async def gen(i):
await asyncio.sleep(0.1)
assert rxgen.rx.value == 5

async def test_reactive_async_gen_pipe_with_dep():
async def gen(value, i):
yield value+i+1
await asyncio.sleep(0.1)
yield value+i+2

irx = rx(0)
rxv = rx(0)
rxgen = rxv.rx.pipe(bind(gen, i=irx))
rxgen.rx.watch()
assert rxgen.rx.value is param.Undefined
await asyncio.sleep(0.05)
assert rxgen.rx.value == 1
irx.rx.value = 3
await asyncio.sleep(0.05)
irx.rx.value = 4
await asyncio.sleep(0.1)
assert rxgen.rx.value == 5
rxv.rx.value = 5
await asyncio.sleep(0.1)
assert rxgen.rx.value == 10

def test_root_invalidation():
arx = rx('a')
brx = rx('b')
Expand Down

0 comments on commit c5f2902

Please sign in to comment.