From 65b0b92c0324acf3cfa3b864747c86fc646f1db9 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 16 Aug 2021 19:17:58 -0500 Subject: [PATCH] PytatoPyOpenCLArrayContext: respect arraycontext.loopy's options --- arraycontext/impl/pytato/__init__.py | 5 ++--- arraycontext/impl/pytato/compile.py | 7 ++----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index beaebc4b..88b05e6b 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -45,6 +45,7 @@ import numpy as np from typing import Any, Callable, Union, Sequence, TYPE_CHECKING from pytools.tag import Tag +from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS if TYPE_CHECKING: import pytato @@ -121,7 +122,6 @@ def call_loopy(self, program, **kwargs): def freeze(self, array): import pytato as pt import pyopencl.array as cla - import loopy as lp if isinstance(array, cla.Array): return array.with_queue(None) @@ -150,8 +150,7 @@ def freeze(self, array): pt_prg = self._freeze_prg_cache[normalized_expr] except KeyError: pt_prg = pt.generate_loopy(self.transform_dag(normalized_expr), - options=lp.Options(return_dict=True, - no_numpy=True), + options=_DEFAULT_LOOPY_OPTIONS, cl_device=self.queue.device) pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) self._freeze_prg_cache[normalized_expr] = pt_prg diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index faed2cfe..46fed323 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -31,6 +31,7 @@ from arraycontext import PytatoPyOpenCLArrayContext from arraycontext.container.traversal import (rec_keyed_map_array_container, is_array_container) +from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS import numpy as np from typing import Any, Callable, Tuple, Dict, Mapping @@ -225,15 +226,11 @@ def _as_dict_of_named_arrays(keys, ary): rec_keyed_map_array_container(_as_dict_of_named_arrays, outputs) - import loopy as lp - pt_dict_of_named_arrays = self.actx.transform_dag( pt.make_dict_of_named_arrays(dict_of_named_arrays)) pytato_program = pt.generate_loopy(pt_dict_of_named_arrays, - options=lp.Options( - return_dict=True, - no_numpy=True), + options=_DEFAULT_LOOPY_OPTIONS, cl_device=self.actx.queue.device) assert isinstance(pytato_program, BoundPyOpenCLProgram)