Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change task.return to take component-level type + canonopt immediates #431

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions design/mvp/Async.md
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,16 @@ replaced with `...` to focus on the overall flow of function calls.
```wat
(component
(import "fetch" (func $fetch (param "url" string) (result (list u8))))
(core module $Libc
(memory (export "mem") 1)
(func (export "realloc") (param i32 i32 i32 i32) (result i32) ...)
...
)
Comment on lines +354 to +358
Copy link
Member Author

@lukewagner lukewagner Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: This and most of the changes below in Async.md are pre-existing fixes that were made more evident when thinking about the <canonopt>* immediate of (canon task.return ...) below.

(core module $Main
(import "libc" "mem" (memory 1))
(import "libc" "realloc" (func (param i32 i32 i32 i32) (result i32)))
(import "" "fetch" (func $fetch (param i32 i32) (result i32)))
(import "" "task.return" (func $task_return (param i32)))
(import "" "task.return" (func $task_return (param i32 i32)))
(import "" "task.wait" (func $wait (param i32) (result i32)))
(func (export "summarize") (param i32 i32)
...
Expand All @@ -368,19 +375,25 @@ replaced with `...` to focus on the overall flow of function calls.
...
end
...
call $task_return ;; return the string result
call $task_return ;; return the string result (pointer,length)
...
)
)
(canon lower $fetch async (core func $fetch'))
(canon task.return (core func $task_return))
(canon task.wait (core func $task_wait))
(core instance $libc (instantiate $Libc))
(alias $libc "mem" (core memory $mem))
(alias $libc "realloc" (core func $realloc))
(canon lower $fetch async (memory $mem) (realloc $realloc) (core func $fetch'))
(canon task.return (result string) async (memory $mem) (realloc $realloc) (core func $task_return))
(canon task.wait async (memory $mem) (core func $task_wait))
(core instance $main (instantiate $Main (with "" (instance
(export "mem" (memory $mem))
(export "realloc" (func $realloc))
(export "fetch" (func $fetch'))
(export "task.return" (func $task_return))
(export "task.wait" (func $task_wait))
))))
(canon lift (core func $main "summarize") async
(canon lift (core func $main "summarize")
async (memory $mem) (realloc $realloc)
(func $summarize (param "urls" (list string)) (result string)))
(export "summarize" (func $summarize))
)
Expand Down Expand Up @@ -418,10 +431,16 @@ not externally-visible behavior.
```wat
(component
(import "fetch" (func $fetch (param "url" string) (result (list u8))))
(core module $Libc
(memory (export "mem") 1)
(func (export "realloc") (param i32 i32 i32 i32) (result i32) ...)
...
)
(core module $Main
(import "libc" "mem" (memory 1))
(import "libc" "realloc" (func (param i32 i32 i32 i32) (result i32)))
(import "" "fetch" (func $fetch (param i32 i32) (result i32)))
(import "" "task.return" (func $task_return (param i32)))
(import "" "task.wait" (func $wait (param i32) (result i32)))
(import "" "task.return" (func $task_return (param i32 i32)))
(func (export "summarize") (param i32 i32) (result i32)
...
loop
Expand All @@ -438,20 +457,24 @@ not externally-visible behavior.
return ;; wait for another subtask to make progress
end
...
call $task_return ;; return the string result
call $task_return ;; return the string result (pointer,length)
...
i32.const 0 ;; return zero to signal that this task is done
)
)
(canon lower $fetch async (core func $fetch'))
(canon task.return (core func $task_return))
(canon task.wait (core func $task_wait))
(core instance $libc (instantiate $Libc))
(alias $libc "mem" (core memory $mem))
(alias $libc "realloc" (core func $realloc))
(canon lower $fetch async (memory $mem) (realloc $realloc) (core func $fetch'))
(canon task.return (result string) async (memory $mem) (realloc $realloc) (core func $task_return))
(core instance $main (instantiate $Main (with "" (instance
(export "mem" (memory $mem))
(export "realloc" (func $realloc))
(export "fetch" (func $fetch'))
(export "task.return" (func $task_return))
(export "task.wait" (func $task_wait))
))))
(canon lift (core func $main "summarize") async (callback (core func $main "cb"))
(canon lift (core func $main "summarize")
async (callback (core func $main "cb")) (memory $mem) (realloc $realloc)
(func $summarize (param "urls" (list string)) (result string)))
(export "summarize" (func $summarize))
)
Expand Down
2 changes: 1 addition & 1 deletion design/mvp/Binary.md
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ canon ::= 0x00 0x00 f:<core:funcidx> opts:<opts> ft:<typeidx> => (canon lift
| 0x05 ft:<typeidx> => (canon thread.spawn ft (core func)) 🧵
| 0x06 => (canon thread.available_parallelism (core func)) 🧵
| 0x08 => (canon task.backpressure (core func)) 🔀
| 0x09 ft:<core:typeidx> => (canon task.return ft (core func)) 🔀
| 0x09 rs:<resultlist> opts:<opts> => (canon task.return rs opts (core func)) 🔀
| 0x0a async?:<async>? m:<core:memdix> => (canon task.wait async? (memory m) (core func)) 🔀
| 0x0b async?:<async>? m:<core:memidx> => (canon task.poll async? (memory m) (core func)) 🔀
| 0x0c async?:<async>? => (canon task.yield async? (core func)) 🔀
Expand Down
25 changes: 14 additions & 11 deletions design/mvp/CanonicalABI.md
Original file line number Diff line number Diff line change
Expand Up @@ -2878,28 +2878,31 @@ consume resources.

For a canonical definition:
```wasm
(canon task.return $ft (core func $f))
(canon task.return (result $t)? $opts (core func $f))
```
validation specifies:
* `$f` is given type `$ft`, which validation requires to be a (core) function type
* `$f` is given type `flatten_functype($opts, (func (param $t)?), 'lower')`

Calling `$f` invokes the following function which uses `Task.return_` to lift
and pass the results to the caller:
```python
async def canon_task_return(task, core_ft, flat_args):
async def canon_task_return(task, result_type, opts, flat_args):
trap_if(not task.inst.may_leave)
trap_if(task.opts.sync and not task.opts.always_task_return)
sync_opts = copy(task.opts)
sync_opts.sync = True
trap_if(core_ft != flatten_functype(sync_opts, FuncType(task.ft.results, []), 'lower'))
trap_if(result_type != task.ft.results)
trap_if(opts != task.opts)
task.return_(flat_args)
return []
```
An expected implementation of `task.return` would generate a core wasm function
for each lowering of an `async`-lifted export that performs the fused copy of
the results into the caller, storing the index of this function in the `Task`
structure and using `call_indirect` to perform the function-type-equality check
required here.
The `trap_if(result_type != task.ft.results)` guard ensures that, in a
component with multiple exported functions of different types, `task.return` is
not called with a mismatched result type (which, due to indirect control flow,
can in general only be caught dynamically).

The `trap_if(opts != task.opts)` guard ensures that the return value is lifted
the same way as the `canon lift` from which this `task.return` is returning.
This ensures that AOT fusion of `canon lift` and `canon lower` can generate
a thunk that is indirectly called by `task.return` after these guards.

### 🔀 `canon task.wait`

Expand Down
13 changes: 6 additions & 7 deletions design/mvp/Explainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -1399,7 +1399,7 @@ canon ::= ...
| (canon resource.drop <typeidx> async? (core func <id>?))
| (canon resource.rep <typeidx> (core func <id>?))
| (canon task.backpressure (core func <id>?)) 🔀
| (canon task.return <core:typeidx> (core func <id>?)) 🔀
| (canon task.return (result <valtype>)? <canonopt>* (core func <id>?)) 🔀
| (canon task.wait async? (memory <core:memidx>) (core func <id>?)) 🔀
| (canon task.poll async? (memory <core:memidx>) (core func <id>?)) 🔀
| (canon task.yield async? (core func <id>?)) 🔀
Expand Down Expand Up @@ -1543,12 +1543,11 @@ the Canonical ABI explainer.)

The `task.return` built-in takes as parameters the result values of the
currently-executing task. This built-in must be called exactly once per export
activation. The `canon task.return` definition takes the type index of a core
function type and produces a core function with exactly that type. When called,
the declared core function type is checked to match the lowered function type
of a component-level function taking the result types of the current task. (See
also [Returning](Async.md#returning) in the async explainer and
[`canon_task_return`] in the Canonical ABI explainer.)
activation. The `canon task.return` definition takes component-level return
type and the list of `canonopt` to be used to lift the return value. When
called, the declared return type and `canonopt`s are checked to exactly match
those of the current task. (See also [Returning](Async.md#returning) in the
async explainer and [`canon_task_return`] in the Canonical ABI explainer.)

###### 🔀 `task.wait`

Expand Down
7 changes: 3 additions & 4 deletions design/mvp/canonical-abi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,12 +1855,11 @@ async def canon_task_backpressure(task, flat_args):

### 🔀 `canon task.return`

async def canon_task_return(task, core_ft, flat_args):
async def canon_task_return(task, result_type, opts, flat_args):
trap_if(not task.inst.may_leave)
trap_if(task.opts.sync and not task.opts.always_task_return)
sync_opts = copy(task.opts)
sync_opts.sync = True
trap_if(core_ft != flatten_functype(sync_opts, FuncType(task.ft.results, []), 'lower'))
trap_if(result_type != task.ft.results)
trap_if(opts != task.opts)
task.return_(flat_args)
return []

Expand Down
30 changes: 15 additions & 15 deletions design/mvp/canonical-abi/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ async def test_async_to_async():
eager_ft = FuncType([], [U8Type()])
async def core_eager_producer(task, args):
assert(len(args) == 0)
[] = await canon_task_return(task, CoreFuncType(['i32'],[]), [43])
[] = await canon_task_return(task, [U8Type()], producer_opts, [43])
return []
eager_callee = partial(canon_lift, producer_opts, producer_inst, eager_ft, core_eager_producer)

Expand All @@ -537,7 +537,7 @@ async def core_toggle(task, args):
[] = await canon_task_backpressure(task, [1])
await task.on_block(fut1)
[] = await canon_task_backpressure(task, [0])
[] = await canon_task_return(task, CoreFuncType([],[]), [])
[] = await canon_task_return(task, [], producer_opts, [])
return []
toggle_callee = partial(canon_lift, producer_opts, producer_inst, toggle_ft, core_toggle)

Expand All @@ -547,7 +547,7 @@ async def core_blocking_producer(task, args):
[x] = args
assert(x == 83)
await task.on_block(fut2)
[] = await canon_task_return(task, CoreFuncType(['i32'],[]), [44])
[] = await canon_task_return(task, [U8Type()], producer_opts, [44])
await task.on_block(fut3)
return []
blocking_callee = partial(canon_lift, producer_opts, producer_inst, blocking_ft, core_blocking_producer)
Expand Down Expand Up @@ -613,7 +613,7 @@ async def dtor(task, args):
assert(callidx == 2)
[] = await canon_subtask_drop(task, callidx)

[] = await canon_task_return(task, CoreFuncType(['i32'],[]), [42])
[] = await canon_task_return(task, [U8Type()], consumer_opts, [42])
return []

ft = FuncType([BoolType()],[U8Type()])
Expand Down Expand Up @@ -641,7 +641,7 @@ async def test_async_callback():
async def core_producer_pre(fut, task, args):
assert(len(args) == 0)
await task.on_block(fut)
await canon_task_return(task, CoreFuncType([],[]), [])
await canon_task_return(task, [], producer_opts, [])
return []
fut1 = asyncio.Future()
core_producer1 = partial(core_producer_pre, fut1)
Expand Down Expand Up @@ -683,7 +683,7 @@ async def callback(task, args):
assert(args[2] == 2)
assert(args[3] == 0)
await canon_subtask_drop(task, 2)
[] = await canon_task_return(task, CoreFuncType(['i32'],[]), [83])
[] = await canon_task_return(task, [U32Type()], opts, [83])
return [0]

consumer_inst = ComponentInstance()
Expand Down Expand Up @@ -761,7 +761,7 @@ async def consumer(task, args):

assert(await task.poll(sync = True) is None)

await canon_task_return(task, CoreFuncType(['i32'],[]), [83])
await canon_task_return(task, [U8Type()], consumer_opts, [83])
return []

consumer_inst = ComponentInstance()
Expand All @@ -786,7 +786,7 @@ async def test_async_backpressure():
producer1_done = False
async def producer1_core(task, args):
nonlocal producer1_done
await canon_task_return(task, CoreFuncType([],[]), [])
await canon_task_return(task, [], producer_opts, [])
await canon_task_backpressure(task, [1])
await task.on_block(fut)
await canon_task_backpressure(task, [0])
Expand All @@ -797,7 +797,7 @@ async def producer1_core(task, args):
async def producer2_core(task, args):
nonlocal producer2_done
assert(producer1_done == True)
await canon_task_return(task, CoreFuncType([],[]), [])
await canon_task_return(task, [], producer_opts, [])
producer2_done = True
return []

Expand Down Expand Up @@ -837,7 +837,7 @@ async def consumer(task, args):

assert(await task.poll(sync = False) is None)

await canon_task_return(task, CoreFuncType(['i32'],[]), [84])
await canon_task_return(task, [U8Type()], consumer_opts, [84])
return []

consumer_inst = ComponentInstance()
Expand All @@ -860,7 +860,7 @@ async def test_sync_using_wait():

async def core_hostcall_pre(fut, task, args):
await task.on_block(fut)
[] = await canon_task_return(task, CoreFuncType([],[]), [])
[] = await canon_task_return(task, [], hostcall_opts, [])
return []
fut1 = asyncio.Future()
core_hostcall1 = partial(core_hostcall_pre, fut1)
Expand Down Expand Up @@ -1048,7 +1048,7 @@ async def core_func(task, args):
rsi1 = args[0]
assert(rsi1 == 1)
[wsi1] = await canon_stream_new(U8Type(), task)
[] = await canon_task_return(task, CoreFuncType(['i32'],[]), [wsi1])
[] = await canon_task_return(task, [StreamType(U8Type())], opts, [wsi1])
[ret] = await canon_stream_read(U8Type(), opts, task, rsi1, 0, 4)
assert(ret == 4)
assert(mem[0:4] == b'\x01\x02\x03\x04')
Expand Down Expand Up @@ -1122,7 +1122,7 @@ async def core_func(task, args):
[rsi1] = args
assert(rsi1 == 1)
[wsi1] = await canon_stream_new(U8Type(), task)
[] = await canon_task_return(task, CoreFuncType(['i32'],[]), [wsi1])
[] = await canon_task_return(task, [StreamType(U8Type())], opts, [wsi1])
[ret] = await canon_stream_read(U8Type(), opts, task, rsi1, 0, 4)
assert(ret == definitions.BLOCKED)
src_stream.write([1,2,3,4])
Expand Down Expand Up @@ -1328,7 +1328,7 @@ async def test_wasm_to_wasm_stream():
async def core_func1(task, args):
assert(not args)
[wsi] = await canon_stream_new(U8Type(), task)
[] = await canon_task_return(task, CoreFuncType(['i32'], []), [wsi])
[] = await canon_task_return(task, [StreamType(U8Type())], opts1, [wsi])

await task.on_block(fut1)

Expand Down Expand Up @@ -1367,7 +1367,7 @@ async def core_func1(task, args):
ft2 = FuncType([], [])
async def core_func2(task, args):
assert(not args)
[] = await canon_task_return(task, CoreFuncType([], []), [])
[] = await canon_task_return(task, [], opts2, [])

retp = 0
[ret] = await canon_lower(opts2, ft1, func1, task, [retp])
Expand Down
Loading