Skip to content

Commit

Permalink
Annotation for mildly typed cgen
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 28, 2024
1 parent 1af4523 commit d71475e
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 26 deletions.
22 changes: 11 additions & 11 deletions loopy/target/c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ def get_function_definition(
tv.initializer is not None):
assert tv.read_only

decl = self.wrap_global_constant(
decl: Generable = self.wrap_global_constant(
self.get_temporary_var_declarator(codegen_state, tv))

if tv.initializer is not None:
Expand Down Expand Up @@ -850,20 +850,20 @@ def get_function_declaration(

from cgen import FunctionDeclaration, Value

name = codegen_result.current_program(codegen_state).name
name_str = codegen_result.current_program(codegen_state).name
if self.target.fortran_abi:
name += "_"
name_str += "_"

if codegen_state.is_entrypoint:
name = Value("void", name)
name: Declarator = Value("void", name_str)

# subkernel launches occur only as part of entrypoint kernels for now
from loopy.schedule.tools import get_subkernel_arg_info
skai = get_subkernel_arg_info(kernel, subkernel_name)
passed_names = skai.passed_names
written_names = skai.written_names
else:
name = Value("static void", name)
name = Value("static void", name_str)
passed_names = [arg.name for arg in kernel.args]
written_names = kernel.get_written_variables()

Expand Down Expand Up @@ -892,11 +892,11 @@ def emit_temp_var_decl_for_tv_with_base_storage(self,
assert isinstance(tv.address_space, AddressSpace)
ecm = codegen_state.expression_to_code_mapper

cast_decl = POD(self, tv.dtype, "")
temp_var_decl = POD(self, tv.dtype, tv.name)
cast_decl: Declarator = POD(self, tv.dtype, "")
temp_var_decl: Declarator = POD(self, tv.dtype, tv.name)

if tv._base_storage_access_may_be_aliasing:
ptrtype = _ConstPointer
ptrtype: type[Pointer] = _ConstPointer
else:
# The 'restrict' part of this is a complete lie--of course
# all these temporaries are aliased. But we're promising to
Expand Down Expand Up @@ -1018,7 +1018,7 @@ def wrap_global_constant(self, decl: Declarator) -> Declarator:

def get_value_arg_declaraotor(
self, name: str, dtype: LoopyType, is_written: bool) -> Declarator:
result = POD(self, dtype, name)
result: Declarator = POD(self, dtype, name)

if not is_written:
from cgen import Const
Expand Down Expand Up @@ -1048,7 +1048,7 @@ def get_array_base_declarator(self, ary: ArrayBase) -> Declarator:
def get_array_arg_declarator(
self, arg: ArrayArg, is_written: bool) -> Declarator:
from cgen import RestrictPointer
arg_decl = RestrictPointer(
arg_decl: Declarator = RestrictPointer(
self.wrap_decl_for_address_space(
self.get_array_base_declarator(arg), arg.address_space))

Expand All @@ -1070,7 +1070,7 @@ def get_temporary_arg_decl(
from cgen import RestrictPointer
assert temp_var.address_space is not auto

arg_decl = RestrictPointer(
arg_decl: Declarator = RestrictPointer(
self.wrap_decl_for_address_space(
self.get_array_base_declarator(temp_var),
cast(AddressSpace, temp_var.address_space)))
Expand Down
10 changes: 5 additions & 5 deletions loopy/target/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import numpy as np

from cgen import Const, Declarator, Generable
from cgen import Const, Declarator, Generable, Pointer
from pymbolic import var
from pytools import memoize_method

Expand Down Expand Up @@ -448,7 +448,7 @@ def get_array_base_declarator(self, ary: ArrayBase) -> Declarator:
def get_array_arg_declarator(
self, arg: ArrayArg, is_written: bool) -> Declarator:
from cgen.cuda import CudaRestrictPointer
arg_decl = CudaRestrictPointer(
arg_decl: Declarator = CudaRestrictPointer(
self.get_array_base_declarator(arg))

if not is_written:
Expand Down Expand Up @@ -477,11 +477,11 @@ def emit_temp_var_decl_for_tv_with_base_storage(self,
assert tv.base_storage is not None
ecm = codegen_state.expression_to_code_mapper

cast_decl = POD(self, tv.dtype, "")
temp_var_decl = POD(self, tv.dtype, tv.name)
cast_decl: Declarator = POD(self, tv.dtype, "")
temp_var_decl: Declarator = POD(self, tv.dtype, tv.name)

if tv._base_storage_access_may_be_aliasing:
ptrtype = _ConstPointer
ptrtype: type[Pointer] = _ConstPointer
else:
# The 'restrict' part of this is a complete lie--of course
# all these temporaries are aliased. But we're promising to
Expand Down
4 changes: 2 additions & 2 deletions loopy/target/ispc.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def get_function_declaration(
for arg_name in passed_names]

if codegen_state.is_generating_device_code:
result = ISPCTask(
result: Declarator = ISPCTask(
FunctionDeclaration(
Value("void", name),
arg_decls))
Expand Down Expand Up @@ -323,7 +323,7 @@ def get_array_arg_declarator(
self, arg: ArrayArg, is_written: bool) -> Declarator:
# FIXME restrict?
from cgen.ispc import ISPCUniform, ISPCUniformPointer
decl = ISPCUniform(
decl: Declarator = ISPCUniform(
ISPCUniformPointer(self.get_array_base_declarator(arg)))

if not is_written:
Expand Down
15 changes: 8 additions & 7 deletions loopy/target/pyopencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
Block,
Collection,
Const,
Declarator,
FunctionBody,
Generable,
Initializer,
Expand Down Expand Up @@ -1027,7 +1028,7 @@ def get_function_definition(
self, codegen_state: CodeGenerationState,
codegen_result: CodeGenerationResult,
schedule_index: int, function_decl: Generable, function_body: Generable,
) -> Tuple[Sequence[Tuple[str, str]], Generable]:
) -> Generable:
assert isinstance(function_body, Block)
kernel = codegen_state.kernel
assert kernel.linearization is not None
Expand Down Expand Up @@ -1055,7 +1056,7 @@ def get_function_definition(
tv.initializer is not None):
assert tv.read_only

decl = self.wrap_global_constant(
decl: Generable = self.wrap_global_constant(
self.get_temporary_var_declarator(codegen_state, tv))

if tv.initializer is not None:
Expand Down Expand Up @@ -1109,14 +1110,14 @@ def get_function_declaration(

from cgen import FunctionDeclaration, Struct, Value

name = codegen_result.current_program(codegen_state).name
name_str = codegen_result.current_program(codegen_state).name
if self.target.fortran_abi:
name += "_"
name_str += "_"

from loopy.target.c import FunctionDeclarationWrapper

if codegen_state.is_entrypoint:
name = Value("void", name)
name = Value("void", name_str)

# subkernel launches occur only as part of entrypoint kernels for now
from loopy.schedule.tools import get_subkernel_arg_info
Expand Down Expand Up @@ -1146,7 +1147,7 @@ def get_function_declaration(
(f"declare-{arg_overflow_struct_name}",
str(arg_overflow_struct))
] if struct_overflow_arg_names else []
arg_struct_args = [CLGlobal(Const(Pointer(Value(
arg_struct_args: list[Declarator] = [CLGlobal(Const(Pointer(Value(
f"struct {arg_overflow_struct_name}",
"_lpy_overflow_args"))))]
else:
Expand All @@ -1165,7 +1166,7 @@ def get_function_declaration(
+ arg_struct_args
)))
else:
name = Value("static void", name)
name = Value("static void", name_str)
passed_names = [arg.name for arg in kernel.args]
written_names = kernel.get_written_variables()

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ exclude = [
module = [
"islpy.*",
"pymbolic.*",
"cgen.*",
"genpy.*",
"pyopencl.*",
"colorama.*",
Expand Down

0 comments on commit d71475e

Please sign in to comment.