1+ from __future__ import annotations
2+
3+
14"""
25.. currentmodule:: arraycontext
36
3033THE SOFTWARE.
3134"""
3235
33- from typing import Any , Dict
36+ from typing import Any
3437
3538import numpy as np
3639
3942
4043from arraycontext .container .traversal import rec_map_array_container , with_array_context
4144from arraycontext .context import (
45+ Array ,
4246 ArrayContext ,
4347 ArrayOrContainerOrScalar ,
4448 ArrayOrContainerOrScalarT ,
@@ -62,10 +66,12 @@ class NumpyArrayContext(ArrayContext):
6266
6367 .. automethod:: __init__
6468 """
69+
70+ _loopy_transform_cache : dict [lp .TranslationUnit , lp .ExecutorBase ]
71+
6572 def __init__ (self ) -> None :
6673 super ().__init__ ()
67- self ._loopy_transform_cache : \
68- Dict [lp .TranslationUnit , lp .TranslationUnit ] = {}
74+ self ._loopy_transform_cache = {}
6975
7076 array_types = (NumpyNonObjectArray ,)
7177
@@ -88,17 +94,18 @@ def to_numpy(self,
8894 ) -> NumpyOrContainerOrScalar :
8995 return array
9096
91- def call_loopy (self , t_unit , ** kwargs ):
97+ def call_loopy (
98+ self ,
99+ t_unit : lp .TranslationUnit , ** kwargs : Any
100+ ) -> dict [str , Array ]:
92101 t_unit = t_unit .copy (target = lp .ExecutableCTarget ())
93102 try :
94- t_unit = self ._loopy_transform_cache [t_unit ]
103+ executor = self ._loopy_transform_cache [t_unit ]
95104 except KeyError :
96- orig_t_unit = t_unit
97- t_unit = self .transform_loopy_program (t_unit )
98- self ._loopy_transform_cache [orig_t_unit ] = t_unit
99- del orig_t_unit
105+ executor = self .transform_loopy_program (t_unit ).executor ()
106+ self ._loopy_transform_cache [t_unit ] = executor
100107
101- _ , result = t_unit (** kwargs )
108+ _ , result = executor (** kwargs )
102109
103110 return result
104111
0 commit comments