Skip to content

Commit

Permalink
[cpu] Add test_annotations.py to CI (#81)
Browse files Browse the repository at this point in the history
The only test that needed fixing was `test_unknown_annotations`, where we were
generating invalid code for the launcher. In particular, when
`kernel_fn_args` was empty, we would get the following error:

```
/var/folders/_z/88s630fd3d9fx72mbmx90qvw0000gn/T/tmpy481mz0l/main.cpp:37:29: error: expected ';' before '(' token
   37 | using kernel_ptr_t = void(*)(, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t);
      |                             ^
      |                             ;
```
  • Loading branch information
int3 authored Jul 30, 2024
1 parent f2352e7 commit ed82c33
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,6 @@ jobs:
run: |
python -m pytest -s -n 32 --device cpu python/test/unit/language/test_core.py -m cpu
python -m pytest -s -n 32 --device cpu \
python/test/unit/language/test_annotations.py \
python/test/unit/language/test_block_pointer.py \
python/test/unit/language/test_conversions.py
7 changes: 3 additions & 4 deletions third_party/cpu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,10 @@ def format_of(ty):

args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
format = "iiiOKOOOO" + args_format
arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items())
kernel_fn_args = [i for i in signature.keys() if i not in constants]
kernel_fn_args_list = ', '.join(f"arg{i}" for i in kernel_fn_args) if len(kernel_fn_args) > 0 else ''
kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + ", "
if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t"
kernel_fn_args_list = ', '.join(f"arg{i}" for i in kernel_fn_args)
kernel_fn_arg_types = ', '.join([f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args] + ["uint32_t"] * 6)

# generate glue code
src = f"""
Expand Down

0 comments on commit ed82c33

Please sign in to comment.