Skip to content

Commit

Permalink
Work around long compile times for complex fuzzed expressions on gcc …
Browse files Browse the repository at this point in the history
…11+ (gh-686)
  • Loading branch information
inducer committed Oct 3, 2022
1 parent 6f6a762 commit 8a2b06b
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions test/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,10 @@ def assert_parse_roundtrip(expr):
assert expr == parsed_expr


@pytest.mark.parametrize("target", [lp.PyOpenCLTarget, lp.ExecutableCTarget])
@pytest.mark.parametrize("target_cls", [lp.PyOpenCLTarget, lp.ExecutableCTarget])
@pytest.mark.parametrize("random_seed", [0, 1, 2, 3, 4, 5])
@pytest.mark.parametrize("expr_type", ["int", "int_nonneg", "real", "complex"])
def test_fuzz_expression_code_gen(ctx_factory, expr_type, random_seed, target):
def test_fuzz_expression_code_gen(ctx_factory, expr_type, random_seed, target_cls):
from pymbolic import evaluate

def get_numpy_type(x):
Expand Down Expand Up @@ -290,7 +290,7 @@ def get_numpy_type(x):

var_name = "expr%d" % i

print(expr)
# print(expr)
#assert_parse_roundtrip(expr)

if expr_type in ["int", "int_nonneg"]:
Expand All @@ -299,7 +299,7 @@ def get_numpy_type(x):
var_values,
lbound=result_type_iinfo.min,
ubound=result_type_iinfo.max)
print(expr)
# print(expr)
try:
ref_values[var_name] = bceval_mapper(expr)
except BoundsCheckError:
Expand Down Expand Up @@ -330,8 +330,27 @@ def get_numpy_type(x):
if expr_type == "int_nonneg":
var_names.extend(var_values)

if issubclass(target_cls, lp.ExecutableCTarget):
# https://github.com/inducer/loopy/issues/686
# https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107127

from shutil import which
gcc_10 = which("gcc-10")
if gcc_10 is not None:
from loopy.target.c.c_execution import CCompiler
target = target_cls(compiler=CCompiler(cc=gcc_10))
else:
from warnings import warn
warn("Using default C compiler because gcc-10 was not found. "
"These tests may take a long time, because of "
"https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107127.")
target = target_cls()

else:
target = target_cls()

knl = lp.make_kernel("{ : }", instructions, data, seq_dependencies=True,
target=target())
target=target)

import islpy as isl
knl = lp.assume(knl, isl.BasicSet(
Expand All @@ -341,14 +360,14 @@ def get_numpy_type(x):
" and ".join("%s >= 0" % name for name in var_names))))

knl = lp.set_options(knl, return_dict=True)
print(knl)
# print(knl)

if target == lp.PyOpenCLTarget:
if type(target) is lp.PyOpenCLTarget:
cl_ctx = ctx_factory()
knl = lp.set_options(knl, write_code=True)
with cl.CommandQueue(cl_ctx) as queue:
evt, lp_values = knl(queue, out_host=True)
elif target == lp.ExecutableCTarget:
elif type(target) is lp.ExecutableCTarget:
evt, lp_values = knl()
else:
raise NotImplementedError("unsupported target")
Expand Down

0 comments on commit 8a2b06b

Please sign in to comment.