13
13
from loki .transform import Transformation
14
14
from loki import (
15
15
FindVariables , DerivedType , SymbolAttributes ,
16
- Array , single_variable_declaration , Transformer
16
+ Array , single_variable_declaration , Transformer ,
17
+ BasicType
17
18
)
19
+ import pickle
20
+ import os
18
21
19
22
__all__ = ['ParallelRoutineDispatchTransformation' ]
20
23
@@ -26,11 +29,15 @@ def __init__(self):
26
29
"KLON" , "YDCPG_OPTS%KLON" , "YDGEOMETRY%YRDIM%NPROMA" ,
27
30
"KPROMA" , "YDDIM%NPROMA" , "NPROMA"
28
31
]
32
+ #TODO : do smthg for opening field_index.pkl
33
+ with open (os .getcwd ()+ "/transformations/transformations/field_index.pkl" , 'rb' ) as fp :
34
+ self .map_index = pickle .load (fp )
29
35
# CALL FIELD_NEW (YL_ZA, UBOUNDS=[KLON, KFLEVG, KGPBLKS], LBOUNDS=[1, 0, 1], PERSISTENT=.TRUE.)
30
36
self .new_calls = []
31
37
# IF (ASSOCIATED (YL_ZA)) CALL FIELD_DELETE (YL_ZA)
32
38
self .delete_calls = []
33
39
self .routine_map_temp = {}
40
+ self .routine_map_derived = {}
34
41
35
42
def transform_subroutine (self , routine , ** kwargs ):
36
43
with pragma_regions_attached (routine ):
@@ -40,6 +47,7 @@ def transform_subroutine(self, routine, **kwargs):
40
47
single_variable_declaration (routine )
41
48
self .add_temp (routine )
42
49
self .add_field (routine )
50
+ self .add_derived (routine )
43
51
#call add_arrays etc...
44
52
45
53
def process_parallel_region (self , routine , region ):
@@ -61,11 +69,15 @@ def process_parallel_region(self, routine, region):
61
69
region .append (dr_hook_calls [1 ])
62
70
63
71
region_map_temp = self .decl_local_array (routine , region )
72
+ region_map_derived = self .decl_derived_types (routine , region )
64
73
65
74
for var_name in region_map_temp :
66
75
if var_name not in self .routine_map_temp :
67
76
self .routine_map_temp [var_name ]= region_map_temp [var_name ]
68
77
78
+ for var_name in region_map_derived :
79
+ if var_name not in self .routine_map_derived :
80
+ self .routine_map_derived [var_name ]= region_map_derived [var_name ]
69
81
70
82
71
83
@staticmethod
@@ -178,4 +190,51 @@ def add_field(self, routine):
178
190
routine , cdname = 'DELETE_TEMPORARIES' ,
179
191
handle = sym .Variable (name = 'ZHOOK_HANDLE_FIELD_API' , scope = routine )
180
192
)
181
- routine .body .insert (- 2 ,(dr_hook_calls [0 ], ir .Comment (text = '' ), * self .delete_calls , dr_hook_calls [1 ]))
193
+ routine .body .insert (- 2 ,(dr_hook_calls [0 ], ir .Comment (text = '' ), * self .delete_calls , dr_hook_calls [1 ]))
194
+
195
+ def decl_derived_types (self , routine , region ):
196
+ region_map_derived = {}
197
+ derived = [var for var in FindVariables ().visit (region ) if var .name_parts [0 ] in routine .arguments ]
198
+ for var in derived :
199
+
200
+ key = f"{ routine .variable_map [var .name_parts [0 ]].type .dtype .name } %{ '%' .join (var .name_parts [1 :])} "
201
+ if key in self .map_index :
202
+ value = self .map_index [key ]
203
+ # Creating the pointer on the data : YL_A
204
+ data_name = f"Z_{ var .name .replace ('%' , '_' )} "
205
+ if "REAL" and "JPRB" in value [0 ]:
206
+ data_type = SymbolAttributes (
207
+ dtype = BasicType .REAL , kind = routine .symbol_map ['JPRB' ],
208
+ pointer = True
209
+ )
210
+ data_dim = value [2 ] + 1
211
+ data_shape = (sym .RangeIndex ((None , None )),) * data_dim
212
+ ptr_var = sym .Variable (name = data_name , type = data_type , dimensions = data_shape , scope = routine )
213
+
214
+ else :
215
+ raise NotImplementedError ("This type isn't implemented yet" )
216
+
217
+ # Creating the pointer on the field api object : YL%FA, YL%F_A...
218
+ if routine .variable_map [var .name_parts [0 ]].type .dtype .name == "MF_PHYS_SURF_TYPE" :
219
+ # YL%PA becomes YL%F_A
220
+ field_name = f"{ '%' .join (var .name_parts [:- 1 ])} %F_{ var .name_parts [- 1 ][1 :]} "
221
+ elif routine .variable_map [var .name_parts [0 ]].type .dtype .name == "FIELD_VARIABLES" :
222
+ # YL%A becomes YL%FA
223
+ field_name = f"{ '%' .join (var .name_parts [:- 1 ])} %F{ var .name_parts [- 1 ]} "
224
+ if var .name_parts [- 1 ]== "P" : #YL%FP = YL%FT0
225
+ field_name = f"{ field_name [- 1 ]} T0"
226
+ else :
227
+ # YL%A becomes YL%F_A
228
+ field_name = f"{ '%' .join (var .name_parts [:- 1 ])} %F_{ var .name_parts [- 1 ]} "
229
+ field_ptr_var = var .clone (name = field_name )
230
+ region_map_derived [var .name ] = [field_ptr_var , ptr_var ]
231
+ return (region_map_derived )
232
+
233
+ def add_derived (self , routine ):
234
+ ptr_var = ()
235
+ for value in self .routine_map_derived .values ():
236
+ dcl = ir .VariableDeclaration (
237
+ symbols = (value [1 ],)
238
+ )
239
+ ptr_var += (dcl ,)
240
+ routine .spec .append (ptr_var )
0 commit comments