Skip to content

Commit

Permalink
Allow "await WgpuAwaitable(..)"
Browse files Browse the repository at this point in the history
The object returned by WgpuAwaitable is now directly awaitable. Make Almar happy.
  • Loading branch information
fyellin committed Nov 6, 2024
1 parent 3b024be commit 81cf2f7
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 27 deletions.
3 changes: 2 additions & 1 deletion tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

import anyio

import pytest
Expand Down Expand Up @@ -28,7 +29,7 @@ def poll_function():
awaitable = WgpuAwaitable("test", callback, finalizer, poll_function)

if use_async:
result = await awaitable.async_wait()
result = await awaitable
else:
result = awaitable.sync_wait()
assert result == 10 * 10
Expand Down
12 changes: 6 additions & 6 deletions wgpu/backends/wgpu_native/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ async def request_adapter_async(
force_fallback_adapter=force_fallback_adapter,
canvas=canvas,
) # no-cover
return await awaitable.async_wait()
return await awaitable

def _request_adapter(
self, *, power_preference=None, force_fallback_adapter=False, canvas=None
Expand Down Expand Up @@ -873,7 +873,7 @@ async def request_device_async(
)
# Note that although we claim this function is async, the callback always
# happens inside the call to libf.wgpuAdapterRequestDevice
return await awaitable.async_wait()
return await awaitable

def _request_device(
self,
Expand Down Expand Up @@ -1602,7 +1602,7 @@ def finalizer(id):
self._internal, descriptor, callback, ffi.NULL
)

return await awaitable.async_wait()
return await awaitable

def _create_compute_pipeline_descriptor(
self,
Expand Down Expand Up @@ -1703,7 +1703,7 @@ def finalizer(id):
self._internal, descriptor, callback, ffi.NULL
)

return await awaitable.async_wait()
return await awaitable

def _create_render_pipeline_descriptor(
self,
Expand Down Expand Up @@ -2079,7 +2079,7 @@ async def map_async(
self, mode: flags.MapMode, offset: int = 0, size: Optional[int] = None
):
awaitable = self._map(mode, offset, size) # for now
return await awaitable.async_wait()
return await awaitable

def _map(self, mode, offset=0, size=None):
sync_on_read = True
Expand Down Expand Up @@ -3539,7 +3539,7 @@ def on_submitted_work_done_sync(self):

async def on_submitted_work_done_async(self):
awaitable = self._on_submitted_word_done()
await awaitable.async_wait()
await awaitable

def _on_submitted_word_done(self):
@ffi.callback("void(WGPUQueueWorkDoneStatus, void*)")
Expand Down
48 changes: 28 additions & 20 deletions wgpu/backends/wgpu_native/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,30 +251,11 @@ def set_result(self, result):
def set_error(self, error):
self.result = (None, error)

def sync_wait(self):
if not self.poll_function:
if self.result is None:
raise RuntimeError("Expected callback to have already happened")
else:
while not self._is_done():
time.sleep(self.SLEEP_TIME)
return self.finish()

async def async_wait(self):
if not self.poll_function:
if self.result is None:
raise RuntimeError("Expected callback to have already happened")
else:
while not self._is_done():
# A bug in anyio prevents us from waiting on an Event()
await anyio.sleep(self.SLEEP_TIME)
return self.finish()

def _is_done(self):
self.poll_function()
return self.result is not None or time.perf_counter() > self.maxtime

def finish(self):
def _finish(self):
if not self.result:
raise RuntimeError(f"Waiting for {self.title} timed out.")
result, error = self.result
Expand All @@ -283,6 +264,33 @@ def finish(self):
else:
return self.finalizer(result)

def sync_wait(self):
if self.result is not None:
pass
elif not self.poll_function:
raise RuntimeError("Expected callback to have already happened")
else:
while not self._is_done():
time.sleep(self.SLEEP_TIME)
return self._finish()

def async_wait(self):
return self

def __await__(self):
# There is no documentation on what __await__() is supposed to return, but we
# can certainly copy from a function that *does* know what to return
async def wait_for_callback():
if self.result is not None:
return
if not self.poll_function:
raise RuntimeError("Expected callback to have already happened")
while not self._is_done():
await anyio.sleep(self.SLEEP_TIME)

yield from wait_for_callback().__await__()
return self._finish()


class ErrorHandler:
"""Object that logs errors, with the option to collect incoming
Expand Down

0 comments on commit 81cf2f7

Please sign in to comment.