Skip to content

Commit d1ece0e

Browse files
committed
Add derived_types handling : generating the pt on the data, and the pt on the field api object
1 parent cacb023 commit d1ece0e

File tree

3 files changed

+88
-3
lines changed

3 files changed

+88
-3
lines changed

transformations/tests/test_parallel_routine_dispatch.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,30 @@ def test_parallel_routine_dispatch_decl_field_create_delete(here, frontend):
100100
assert len(conditional) == 5
101101
for cond in conditional:
102102
assert fgen(cond) in field_delete
103-
breakpoint()
103+
104+
@pytest.mark.parametrize('frontend', available_frontends(skip=[OMNI]))
105+
def test_parallel_routine_dispatch_derived(here, frontend):
106+
107+
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
108+
routine = source['dispatch_routine']
109+
110+
transformation = ParallelRoutineDispatchTransformation()
111+
transformation.apply(source['dispatch_routine'])
112+
113+
dcls = [fgen(dcl) for dcl in routine.spec.body[-13:-1]]
114+
115+
test_dcls=["REAL(KIND=JPRB), POINTER :: Z_YDVARS_U_T0(:, :, :)",
116+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_DM(:, :, :)",
117+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_GEOMETRY_GELAM_T0(:, :)",
118+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_T0(:, :, :)",
119+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_DL(:, :, :)",
120+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_V_T0(:, :, :)",
121+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_GEOMETRY_GEMU_T0(:, :)",
122+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_T0(:, :, :)",
123+
"REAL(KIND=JPRB), POINTER :: Z_YDCPG_PHY0_XYB_RDELP(:, :, :)",
124+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_DM(:, :, :)",
125+
"REAL(KIND=JPRB), POINTER :: Z_YDCPG_DYN0_CTY_EVEL(:, :, :)",
126+
"REAL(KIND=JPRB), POINTER :: Z_YDMF_PHYS_SURF_GSD_VF_PZ0F(:, :)",
127+
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_DL(:, :, :)"]
128+
for dcl in dcls:
129+
assert dcl in test_dcls
1.13 MB
Binary file not shown.

transformations/transformations/parallel_routine_dispatch.py

+61-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
from loki.transform import Transformation
1414
from loki import (
1515
FindVariables, DerivedType, SymbolAttributes,
16-
Array, single_variable_declaration, Transformer
16+
Array, single_variable_declaration, Transformer,
17+
BasicType
1718
)
19+
import pickle
20+
import os
1821

1922
__all__ = ['ParallelRoutineDispatchTransformation']
2023

@@ -26,11 +29,15 @@ def __init__(self):
2629
"KLON", "YDCPG_OPTS%KLON", "YDGEOMETRY%YRDIM%NPROMA",
2730
"KPROMA", "YDDIM%NPROMA", "NPROMA"
2831
]
32+
#TODO : do smthg for opening field_index.pkl
33+
with open(os.getcwd()+"/transformations/transformations/field_index.pkl", 'rb') as fp:
34+
self.map_index = pickle.load(fp)
2935
# CALL FIELD_NEW (YL_ZA, UBOUNDS=[KLON, KFLEVG, KGPBLKS], LBOUNDS=[1, 0, 1], PERSISTENT=.TRUE.)
3036
self.new_calls = []
3137
# IF (ASSOCIATED (YL_ZA)) CALL FIELD_DELETE (YL_ZA)
3238
self.delete_calls = []
3339
self.routine_map_temp = {}
40+
self.routine_map_derived = {}
3441

3542
def transform_subroutine(self, routine, **kwargs):
3643
with pragma_regions_attached(routine):
@@ -40,6 +47,7 @@ def transform_subroutine(self, routine, **kwargs):
4047
single_variable_declaration(routine)
4148
self.add_temp(routine)
4249
self.add_field(routine)
50+
self.add_derived(routine)
4351
#call add_arrays etc...
4452

4553
def process_parallel_region(self, routine, region):
@@ -61,11 +69,15 @@ def process_parallel_region(self, routine, region):
6169
region.append(dr_hook_calls[1])
6270

6371
region_map_temp= self.decl_local_array(routine, region)
72+
region_map_derived= self.decl_derived_types(routine, region)
6473

6574
for var_name in region_map_temp:
6675
if var_name not in self.routine_map_temp:
6776
self.routine_map_temp[var_name]=region_map_temp[var_name]
6877

78+
for var_name in region_map_derived:
79+
if var_name not in self.routine_map_derived:
80+
self.routine_map_derived[var_name]=region_map_derived[var_name]
6981

7082

7183
@staticmethod
@@ -178,4 +190,51 @@ def add_field(self, routine):
178190
routine, cdname='DELETE_TEMPORARIES',
179191
handle=sym.Variable(name='ZHOOK_HANDLE_FIELD_API', scope=routine)
180192
)
181-
routine.body.insert(-2,(dr_hook_calls[0], ir.Comment(text=''), *self.delete_calls, dr_hook_calls[1]))
193+
routine.body.insert(-2,(dr_hook_calls[0], ir.Comment(text=''), *self.delete_calls, dr_hook_calls[1]))
194+
195+
def decl_derived_types(self, routine, region):
196+
region_map_derived = {}
197+
derived = [var for var in FindVariables().visit(region) if var.name_parts[0] in routine.arguments]
198+
for var in derived :
199+
200+
key = f"{routine.variable_map[var.name_parts[0]].type.dtype.name}%{'%'.join(var.name_parts[1:])}"
201+
if key in self.map_index:
202+
value = self.map_index[key]
203+
# Creating the pointer on the data : YL_A
204+
data_name = f"Z_{var.name.replace('%', '_')}"
205+
if "REAL" and "JPRB" in value[0]:
206+
data_type = SymbolAttributes(
207+
dtype=BasicType.REAL, kind=routine.symbol_map['JPRB'],
208+
pointer=True
209+
)
210+
data_dim = value[2] + 1
211+
data_shape = (sym.RangeIndex((None, None)),) * data_dim
212+
ptr_var = sym.Variable(name=data_name, type=data_type, dimensions=data_shape, scope=routine)
213+
214+
else:
215+
raise NotImplementedError("This type isn't implemented yet")
216+
217+
# Creating the pointer on the field api object : YL%FA, YL%F_A...
218+
if routine.variable_map[var.name_parts[0]].type.dtype.name=="MF_PHYS_SURF_TYPE":
219+
# YL%PA becomes YL%F_A
220+
field_name = f"{'%'.join(var.name_parts[:-1])}%F_{var.name_parts[-1][1:]}"
221+
elif routine.variable_map[var.name_parts[0]].type.dtype.name=="FIELD_VARIABLES":
222+
# YL%A becomes YL%FA
223+
field_name = f"{'%'.join(var.name_parts[:-1])}%F{var.name_parts[-1]}"
224+
if var.name_parts[-1]=="P": #YL%FP = YL%FT0
225+
field_name = f"{field_name[-1]}T0"
226+
else:
227+
# YL%A becomes YL%F_A
228+
field_name = f"{'%'.join(var.name_parts[:-1])}%F_{var.name_parts[-1]}"
229+
field_ptr_var = var.clone(name=field_name)
230+
region_map_derived[var.name] = [field_ptr_var, ptr_var]
231+
return(region_map_derived)
232+
233+
def add_derived(self, routine):
234+
ptr_var=()
235+
for value in self.routine_map_derived.values():
236+
dcl = ir.VariableDeclaration(
237+
symbols=(value[1],)
238+
)
239+
ptr_var += (dcl,)
240+
routine.spec.append(ptr_var)

0 commit comments

Comments
 (0)