Skip to content

Commit 348e1bb

Browse files
committed
Add derived_types handling : generating the pt on the data, and the pt on the field api object
1 parent cacb023 commit 348e1bb

File tree

3 files changed

+139
-4
lines changed

3 files changed

+139
-4
lines changed

transformations/tests/test_parallel_routine_dispatch.py

+73-1
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,76 @@ def test_parallel_routine_dispatch_decl_field_create_delete(here, frontend):
100100
assert len(conditional) == 5
101101
for cond in conditional:
102102
assert fgen(cond) in field_delete
103-
breakpoint()
103+
104+
@pytest.mark.parametrize('frontend', available_frontends(skip=[OMNI]))
105+
def test_parallel_routine_dispatch_derived_dcl(here, frontend):
106+
107+
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
108+
routine = source['dispatch_routine']
109+
110+
transformation = ParallelRoutineDispatchTransformation()
111+
transformation.apply(source['dispatch_routine'])
112+
113+
dcls = [fgen(dcl) for dcl in routine.spec.body[-13:-1]]
114+
115+
test_dcls=["REAL(KIND=JPRB), POINTER :: Z_YDVARS_U_T0(:, :, :)",
116+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_DM(:, :, :)",
117+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_GEOMETRY_GELAM_T0(:, :)",
118+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_T0(:, :, :)",
119+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_DL(:, :, :)",
120+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_V_T0(:, :, :)",
121+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_GEOMETRY_GEMU_T0(:, :)",
122+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_T0(:, :, :)",
123+
"REAL(KIND=JPRB), POINTER :: Z_YDCPG_PHY0_XYB_RDELP(:, :, :)",
124+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_DM(:, :, :)",
125+
"REAL(KIND=JPRB), POINTER :: Z_YDCPG_DYN0_CTY_EVEL(:, :, :)",
126+
"REAL(KIND=JPRB), POINTER :: Z_YDMF_PHYS_SURF_GSD_VF_PZ0F(:, :)",
127+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_DL(:, :, :)"]
128+
for dcl in dcls:
129+
assert dcl in test_dcls
130+
131+
@pytest.mark.parametrize('frontend', available_frontends(skip=[OMNI]))
132+
def test_parallel_routine_dispatch_derived_var(here, frontend):
133+
134+
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
135+
routine = source['dispatch_routine']
136+
137+
transformation = ParallelRoutineDispatchTransformation()
138+
transformation.apply(source['dispatch_routine'])
139+
140+
141+
## test_dcls=["REAL(KIND=JPRB), POINTER :: Z_YDVARS_U_T0(:, :, :)",
142+
##"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_DM(:, :, :)",
143+
##"REAL(KIND=JPRB), POINTER :: Z_YDVARS_GEOMETRY_GELAM_T0(:, :)",
144+
##"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_T0(:, :, :)",
145+
##"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_DL(:, :, :)",
146+
##"REAL(KIND=JPRB), POINTER :: Z_YDVARS_V_T0(:, :, :)",
147+
##"REAL(KIND=JPRB), POINTER :: Z_YDVARS_GEOMETRY_GEMU_T0(:, :)",
148+
##"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_T0(:, :, :)",
149+
##"REAL(KIND=JPRB), POINTER :: Z_YDCPG_PHY0_XYB_RDELP(:, :, :)",
150+
##"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_DM(:, :, :)",
151+
##"REAL(KIND=JPRB), POINTER :: Z_YDCPG_DYN0_CTY_EVEL(:, :, :)",
152+
##"REAL(KIND=JPRB), POINTER :: Z_YDMF_PHYS_SURF_GSD_VF_PZ0F(:, :)",
153+
##"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_DL(:, :, :)"]
154+
test_map = {
155+
"YDVARS%GEOMETRY%GEMU%T0" : ["YDVARS%GEOMETRY%GEMU%FT0", "Z_YDVARS_GEOMETRY_GEMU_T0"],
156+
"YDVARS%GEOMETRY%GELAM%T0" : ["YDVARS%GEOMETRY%GELAM%FT0", "Z_YDVARS_GEOMETRY_GELAM_T0"],
157+
"YDVARS%U%T0" : ["YDVARS%U%FT0", "Z_YDVARS_U_T0"],
158+
"YDVARS%V%T0" : ["YDVARS%V%FT0", "Z_YDVARS_V_T0"],
159+
"YDVARS%Q%T0" : ["YDVARS%Q%FT0", "Z_YDVARS_Q_T0"],
160+
"YDVARS%Q%DM" : ["YDVARS%Q%FDM", "Z_YDVARS_Q_DM"],
161+
"YDVARS%Q%DL" : ["YDVARS%Q%FDL", "Z_YDVARS_Q_DL"],
162+
"YDVARS%CVGQ%T0" : ["YDVARS%CVGQ%FT0", "Z_YDVARS_CVGQ_T0"],
163+
"YDVARS%CVGQ%DM" : ["YDVARS%CVGQ%FDM", "Z_YDVARS_CVGQ_DM"],
164+
"YDVARS%CVGQ%DL" : ["YDVARS%CVGQ%FDL", "Z_YDVARS_CVGQ_DL"],
165+
"YDCPG_PHY0%XYB%RDELP" : ["YDCPG_PHY0%XYB%F_RDELP", "Z_YDCPG_PHY0_XYB_RDELP"],
166+
"YDCPG_DYN0%CTY%EVEL" : ["YDCPG_DYN0%CTY%F_EVEL", "Z_YDCPG_DYN0_CTY_EVEL"],
167+
"YDMF_PHYS_SURF%GSD_VF%PZ0F" : ["YDMF_PHYS_SURF%GSD_VF%F_Z0F", "Z_YDMF_PHYS_SURF_GSD_VF_PZ0F"]
168+
}
169+
for var_name in transformation.routine_map_derived:
170+
value = transformation.routine_map_derived[var_name]
171+
field_ptr = value[0]
172+
ptr = value[1]
173+
174+
assert test_map[var_name][0] == field_ptr.name
175+
assert test_map[var_name][1] == ptr.name
1.13 MB
Binary file not shown.

transformations/transformations/parallel_routine_dispatch.py

+66-3
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
from loki.transform import Transformation
1414
from loki import (
1515
FindVariables, DerivedType, SymbolAttributes,
16-
Array, single_variable_declaration, Transformer
16+
Array, single_variable_declaration, Transformer,
17+
BasicType
1718
)
19+
import pickle
20+
import os
1821

1922
__all__ = ['ParallelRoutineDispatchTransformation']
2023

@@ -26,11 +29,19 @@ def __init__(self):
2629
"KLON", "YDCPG_OPTS%KLON", "YDGEOMETRY%YRDIM%NPROMA",
2730
"KPROMA", "YDDIM%NPROMA", "NPROMA"
2831
]
32+
#TODO : do smthg for opening field_index.pkl
33+
with open(os.getcwd()+"/transformations/transformations/field_index.pkl", 'rb') as fp:
34+
self.map_index = pickle.load(fp)
2935
# CALL FIELD_NEW (YL_ZA, UBOUNDS=[KLON, KFLEVG, KGPBLKS], LBOUNDS=[1, 0, 1], PERSISTENT=.TRUE.)
3036
self.new_calls = []
3137
# IF (ASSOCIATED (YL_ZA)) CALL FIELD_DELETE (YL_ZA)
3238
self.delete_calls = []
33-
self.routine_map_temp = {}
39+
# map[name] = [field_ptr, ptr]
40+
# where :
41+
# field_ptr : pointer on field api object
42+
# ptr : pointer to the data
43+
self.routine_map_temp = {}
44+
self.routine_map_derived = {}
3445

3546
def transform_subroutine(self, routine, **kwargs):
3647
with pragma_regions_attached(routine):
@@ -40,6 +51,7 @@ def transform_subroutine(self, routine, **kwargs):
4051
single_variable_declaration(routine)
4152
self.add_temp(routine)
4253
self.add_field(routine)
54+
self.add_derived(routine)
4355
#call add_arrays etc...
4456

4557
def process_parallel_region(self, routine, region):
@@ -61,11 +73,15 @@ def process_parallel_region(self, routine, region):
6173
region.append(dr_hook_calls[1])
6274

6375
region_map_temp= self.decl_local_array(routine, region)
76+
region_map_derived= self.decl_derived_types(routine, region)
6477

6578
for var_name in region_map_temp:
6679
if var_name not in self.routine_map_temp:
6780
self.routine_map_temp[var_name]=region_map_temp[var_name]
6881

82+
for var_name in region_map_derived:
83+
if var_name not in self.routine_map_derived:
84+
self.routine_map_derived[var_name]=region_map_derived[var_name]
6985

7086

7187
@staticmethod
@@ -178,4 +194,51 @@ def add_field(self, routine):
178194
routine, cdname='DELETE_TEMPORARIES',
179195
handle=sym.Variable(name='ZHOOK_HANDLE_FIELD_API', scope=routine)
180196
)
181-
routine.body.insert(-2,(dr_hook_calls[0], ir.Comment(text=''), *self.delete_calls, dr_hook_calls[1]))
197+
routine.body.insert(-2,(dr_hook_calls[0], ir.Comment(text=''), *self.delete_calls, dr_hook_calls[1]))
198+
199+
def decl_derived_types(self, routine, region):
200+
region_map_derived = {}
201+
derived = [var for var in FindVariables().visit(region) if var.name_parts[0] in routine.arguments]
202+
for var in derived :
203+
204+
key = f"{routine.variable_map[var.name_parts[0]].type.dtype.name}%{'%'.join(var.name_parts[1:])}"
205+
if key in self.map_index:
206+
value = self.map_index[key]
207+
# Creating the pointer on the data : YL_A
208+
data_name = f"Z_{var.name.replace('%', '_')}"
209+
if "REAL" and "JPRB" in value[0]:
210+
data_type = SymbolAttributes(
211+
dtype=BasicType.REAL, kind=routine.symbol_map['JPRB'],
212+
pointer=True
213+
)
214+
data_dim = value[2] + 1
215+
data_shape = (sym.RangeIndex((None, None)),) * data_dim
216+
ptr_var = sym.Variable(name=data_name, type=data_type, dimensions=data_shape, scope=routine)
217+
218+
else:
219+
raise NotImplementedError("This type isn't implemented yet")
220+
221+
# Creating the pointer on the field api object : YL%FA, YL%F_A...
222+
if routine.variable_map[var.name_parts[0]].type.dtype.name=="MF_PHYS_SURF_TYPE":
223+
# YL%PA becomes YL%F_A
224+
field_name = f"{'%'.join(var.name_parts[:-1])}%F_{var.name_parts[-1][1:]}"
225+
elif routine.variable_map[var.name_parts[0]].type.dtype.name=="FIELD_VARIABLES":
226+
# YL%A becomes YL%FA
227+
field_name = f"{'%'.join(var.name_parts[:-1])}%F{var.name_parts[-1]}"
228+
if var.name_parts[-1]=="P": #YL%FP = YL%FT0
229+
field_name = f"{field_name[-1]}T0"
230+
else:
231+
# YL%A becomes YL%F_A
232+
field_name = f"{'%'.join(var.name_parts[:-1])}%F_{var.name_parts[-1]}"
233+
field_ptr_var = var.clone(name=field_name)
234+
region_map_derived[var.name] = [field_ptr_var, ptr_var]
235+
return(region_map_derived)
236+
237+
def add_derived(self, routine):
238+
ptr_var=()
239+
for value in self.routine_map_derived.values():
240+
dcl = ir.VariableDeclaration(
241+
symbols=(value[1],)
242+
)
243+
ptr_var += (dcl,)
244+
routine.spec.append(ptr_var)

0 commit comments

Comments
 (0)