Skip to content

Commit 180695a

Browse files
committed
Add compute OpenMP region (+ add a shape for ptr_var in decl_derived_types)
1 parent 4f69e67 commit 180695a

File tree

2 files changed

+164
-14
lines changed

2 files changed

+164
-14
lines changed

transformations/tests/test_parallel_routine_dispatch.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -363,4 +363,58 @@ def test_parallel_routine_dispatch_nullify(here, frontend):
363363
"""
364364

365365
for node in nullify:
366-
assert fgen(node) in test_nullify
366+
assert fgen(node) in test_nullify
367+
368+
369+
@pytest.mark.parametrize('frontend', available_frontends(skip=[OMNI]))
def test_parallel_routine_dispatch_compute_openmp(here, frontend):
    """
    Check the OpenMP compute section built for the parallel region of
    ``dispatch_routine``.

    The section is read from ``transformation.compute['OpenMP']`` and is
    expected to hold, in order: the opening DR_HOOK call, the
    ``YLCPG_BNDS%INIT`` call, the OpenMP pragma and the block loop.
    """

    source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
    routine = source['dispatch_routine']

    transformation = ParallelRoutineDispatchTransformation()
    transformation.apply(routine)

    map_compute = transformation.compute
    compute_openmp = map_compute['OpenMP']

    # Reference text of the generated region. Only the DR_HOOK and INIT lines
    # are matched textually; the pragma, the loop and the call are checked
    # structurally below.
    test_compute = """
IF (LHOOK) CALL DR_HOOK('DISPATCH_ROUTINE:CPPHINP:COMPUTE', 0, ZHOOK_HANDLE_COMPUTE)

CALL YLCPG_BNDS%INIT(YDCPG_OPTS)
!$OMP PARALLEL DO PRIVATE (JBLK) FIRSTPRIVATE (YLCPG_BNDS)

DO JBLK=1,YDCPG_OPTS%KGPBLKS
CALL YLCPG_BNDS%UPDATE(JBLK)

CALL CPPHINP (YDGEOMETRY, YDMODEL, YLCPG_BNDS%KIDIA, YLCPG_BNDS%KFDIA, Z_YDVARS_GEOMETRY_GEMU_T0&
&(:, JBLK), Z_YDVARS_GEOMETRY_GELAM_T0(:, JBLK), Z_YDVARS_U_T0(:,:, JBLK), Z_YDVARS_V_T0(:&
&,:, JBLK), Z_YDVARS_Q_T0(:,:, JBLK), Z_YDVARS_Q_DL(:,:, JBLK), Z_YDVARS_Q_DM(:,:, JBLK), Z_YDVARS_CVGQ_DL &
&(:,:, JBLK), Z_YDVARS_CVGQ_DM(:,:, JBLK), Z_YDCPG_PHY0_XYB_RDELP(:,:, JBLK), Z_YDCPG_DYN0_CTY_EVEL&
&(:,:, JBLK), Z_YDVARS_CVGQ_T0(:,:, JBLK), ZRDG_MU0(:, JBLK), ZRDG_MU0LU(:, JBLK), ZRDG_MU0M&
&(:, JBLK), ZRDG_MU0N(:, JBLK), ZRDG_CVGQ(:,:, JBLK), Z_YDMF_PHYS_SURF_GSD_VF_PZ0F(:, JBLK))
ENDDO

IF (LHOOK) CALL DR_HOOK('DISPATCH_ROUTINE:CPPHINP:COMPUTE', 1, ZHOOK_HANDLE_COMPUTE)
"""

    # Expected arguments of the CPPHINP call, in call order.
    test_call_var = ["YDGEOMETRY", "YDMODEL", "YLCPG_BNDS%KIDIA", "YLCPG_BNDS%KFDIA",
        "Z_YDVARS_GEOMETRY_GEMU_T0(:, JBLK)", "Z_YDVARS_GEOMETRY_GELAM_T0(:, JBLK)", "Z_YDVARS_U_T0(:, :, JBLK)",
        "Z_YDVARS_V_T0(:, :, JBLK)", "Z_YDVARS_Q_T0(:, :, JBLK)", "Z_YDVARS_Q_DL(:, :, JBLK)",
        "Z_YDVARS_Q_DM(:, :, JBLK)", "Z_YDVARS_CVGQ_DL(:, :, JBLK)", "Z_YDVARS_CVGQ_DM(:, :, JBLK)",
        "Z_YDCPG_PHY0_XYB_RDELP(:, :, JBLK)", "Z_YDCPG_DYN0_CTY_EVEL(:, :, JBLK)",
        "Z_YDVARS_CVGQ_T0(:, :, JBLK)", "ZRDG_MU0(:, JBLK)", "ZRDG_MU0LU(:, JBLK)",
        "ZRDG_MU0M(:, JBLK)", "ZRDG_MU0N(:, JBLK)", "ZRDG_CVGQ(:, :, JBLK)",
        "Z_YDMF_PHYS_SURF_GSD_VF_PZ0F(:, JBLK)"
    ]

    # Opening DR_HOOK call and YLCPG_BNDS%INIT call.
    for node in compute_openmp[:2]:
        assert fgen(node) in test_compute

    # The pragma is compared clause by clause so that the check does not
    # depend on the spacing or parenthesisation used when it is emitted.
    pragma = compute_openmp[2]
    assert pragma.keyword == 'OMP'
    clauses = pragma.content.replace('(', ' ').replace(')', ' ').split()
    assert clauses == ['PARALLEL', 'DO', 'PRIVATE', 'JBLK', 'FIRSTPRIVATE', 'YLCPG_BNDS']

    loop = compute_openmp[3]
    assert fgen(loop.bounds) == '1,YDCPG_OPTS%KGPBLKS'
    assert fgen(loop.variable) == 'JBLK'
    assert fgen(loop.body[0]) == 'CALL YLCPG_BNDS%UPDATE(JBLK)'

    call = loop.body[1]
    assert fgen(call.name) == 'CPPHINP'
    # Compare the full argument list: a membership test alone would also pass
    # for a call with missing, duplicated or reordered arguments.
    assert [fgen(arg) for arg in call.arguments] == test_call_var

transformations/transformations/parallel_routine_dispatch.py

+109-13
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from loki import (
1515
FindVariables, DerivedType, SymbolAttributes,
1616
Array, single_variable_declaration, Transformer,
17-
BasicType
17+
BasicType, as_tuple
1818
)
1919
import pickle
2020
import os
@@ -31,6 +31,12 @@ def __init__(self):
3131
"KLON", "YDCPG_OPTS%KLON", "YDGEOMETRY%YRDIM%NPROMA",
3232
"KPROMA", "YDDIM%NPROMA", "NPROMA"
3333
]
34+
self.map_compute = {
35+
"OpenMP" : self.create_compute_openmp,
36+
"OpenMPSingleColumn" : self.create_compute_openmpscc,
37+
"OpenACCSingleColumn" : self.create_compute_openaccscc
38+
}
39+
3440
#TODO : do something for opening field_index.pkl
3541
with open(os.getcwd()+"/transformations/transformations/field_index.pkl", 'rb') as fp:
3642
self.map_index = pickle.load(fp)
@@ -46,6 +52,7 @@ def __init__(self):
4652
self.routine_map_derived = {}
4753

4854
def transform_subroutine(self, routine, **kwargs):
55+
self.get_cpg(routine)
4956
with pragma_regions_attached(routine):
5057
for region in FindNodes(ir.PragmaRegion).visit(routine.body):
5158
if is_loki_pragma(region.pragma):
@@ -56,6 +63,7 @@ def transform_subroutine(self, routine, **kwargs):
5663
self.add_derived(routine)
5764
#call add_arrays etc...
5865

66+
5967
def process_parallel_region(self, routine, region):
6068
pragma_content = region.pragma.content.split(maxsplit=1)
6169
pragma_content = [entry.split('=', maxsplit=1) for entry in pragma_content[1].split(',')]
@@ -79,22 +87,21 @@ def process_parallel_region(self, routine, region):
7987
region_map_derived= self.decl_derived_types(routine, region)
8088

8189
self.get_data = {}
90+
self.compute = {}
8291
### self.synchost = {} #synchost same for all the targets
8392
### self.nullify = {} #synchost same for all the targets
8493

8594
self.synchost = self.create_synchost(routine, region_name, region_map_derived, region_map_temp)
8695
self.nullify = self.create_nullify(routine, region_name, region_map_derived, region_map_temp)
8796

88-
8997
for target in pragma_attrs['target']:
90-
9198
# Q : I would like get_data, synchost and nullify not to be members of the Transformation object; however, I need them to run the tests...
9299
# A : maybe have them as members of the routine while
93100
# Is there an object to handle data that is needed for tests ?
94101
# get_data = self.create_pt_sync(routine, region_name, True, region_map_derived, region_map_temp)
95102
# synchost = self.create_synchost(routine, region_name, True, region_map_derived, region_map_temp)
96103
# nullify = self.create_nullify(routine, region_name, True, region_map_derived, region_map_temp)
97-
self.process_target(routine, target, region_name, region_map_temp, region_map_derived)
104+
self.process_target(routine, region, region_name, region_map_temp, region_map_derived, target)
98105
for var_name in region_map_temp:
99106
if var_name not in self.routine_map_temp:
100107
self.routine_map_temp[var_name]=region_map_temp[var_name]
@@ -103,12 +110,9 @@ def process_parallel_region(self, routine, region):
103110
if var_name not in self.routine_map_derived:
104111
self.routine_map_derived[var_name]=region_map_derived[var_name]
105112

106-
def process_target(self, routine, target, region_name, region_map_temp, region_map_derived):
107-
113+
def process_target(self, routine, region, region_name, region_map_temp, region_map_derived, target):
108114
self.get_data[target] = self.create_pt_sync(routine, target, region_name, True, region_map_derived, region_map_temp)
109-
### self.synchost[target] = self.create_synchost(routine, target, region_name, region_map_derived, region_map_temp)
110-
### self.nullify[target] = self.create_nullify(routine, target, region_name, region_map_derived, region_map_temp)
111-
115+
self.compute[target] = self.map_compute[target](routine, region, region_name, region_map_temp, region_map_derived)
112116

113117
@staticmethod
114118
def create_dr_hook_calls(scope, cdname, handle):
@@ -232,12 +236,12 @@ def decl_derived_types(self, routine, region):
232236
# Creating the pointer on the data : YL_A
233237
data_name = f"Z_{var.name.replace('%', '_')}"
234238
if "REAL" and "JPRB" in value[0]:
239+
data_dim = value[2] + 1
240+
data_shape = (sym.RangeIndex((None, None)),) * data_dim
235241
data_type = SymbolAttributes(
236242
dtype=BasicType.REAL, kind=routine.symbol_map['JPRB'],
237-
pointer=True
243+
pointer=True, shape=data_shape
238244
)
239-
data_dim = value[2] + 1
240-
data_shape = (sym.RangeIndex((None, None)),) * data_dim
241245
ptr_var = sym.Variable(name=data_name, type=data_type, dimensions=data_shape, scope=routine)
242246

243247
else:
@@ -293,7 +297,6 @@ def create_pt_sync(self, routine, target, region_name, is_get_data, region_map_d
293297

294298
call = sym.InlineCall(sym.Variable(name=f"{sync_name}_{intent}"), parameters=(var[0],))
295299
sync_data += [ir.Assignment(lhs=var[1].clone(dimensions=None), rhs=call, ptr=True)]
296-
#sync_data += [ir.Assignment(lhs=(var[1],), rhs=(call,), ptr=True)]
297300

298301
sync_data.append(dr_hook_calls[1])
299302

@@ -315,3 +318,96 @@ def create_nullify(self, routine, region_name, region_map_derived, region_map_te
315318
nullify.append(dr_hook_calls[1])
316319
return nullify
317320

321+
def get_cpg(self, routine):
    """
    Find the CPG options/bounds variables of ``routine`` and declare the
    locals used by the block loop.

    The variables found in ``routine.spec`` are scanned for one whose type
    name is ``CPG_OPTS_TYPE`` and one whose type name is ``CPG_BNDS_TYPE``;
    they are stored as ``self.cpg_opts`` and ``self.cpg_bnds``. Once both are
    known, two variables are created, stored on the transformation and
    declared at the end of ``routine.spec``:

    * ``self.lcpg_bnds``: named after ``self.cpg_bnds`` with ``YD`` replaced
      by ``YL``, with the same type as ``self.cpg_bnds`` minus its intent;
    * ``self.jblk``: the ``INTEGER(KIND=JPIM)`` variable ``JBLK``.

    If the routine does not provide both variables, nothing is declared and
    the attributes are left as they were.

    Parameters
    ----------
    routine :
        The routine whose specification is scanned and extended.

    Raises
    ------
    ValueError
        If the name of the ``CPG_BNDS_TYPE`` variable does not contain ``YD``.
    """
    #Assuming CPG_OPTS_TYPE and CPG_BNDS_TYPE are the same in all the routine.
    found_opts = False
    found_bnds = False
    for var in FindVariables().visit(routine.spec):
        if var.type.dtype.name == "CPG_OPTS_TYPE":
            self.cpg_opts = var
            found_opts = True
        if var.type.dtype.name == "CPG_BNDS_TYPE":
            self.cpg_bnds = var
            found_bnds = True
        if found_opts and found_bnds:
            break

    if not (found_opts and found_bnds):
        # Nothing to declare for a routine without both CPG variables.
        return

    if "YD" not in self.cpg_bnds.name:
        raise ValueError(f"cpg_bnds unexpected name : {self.cpg_bnds.name}")

    # The local bounds variable needs the derived type of the original one to
    # be declared properly; the intent is dropped because it is not a dummy
    # argument.
    lcpg_bnds_name = self.cpg_bnds.name.replace("YD", "YL")
    self.lcpg_bnds = sym.Variable(
        name=lcpg_bnds_name, type=self.cpg_bnds.type.clone(intent=None), scope=routine
    )
    routine.spec.append(ir.VariableDeclaration(symbols=as_tuple(self.lcpg_bnds)))

    data_type = SymbolAttributes(
        dtype=BasicType.INTEGER, kind=routine.symbol_map['JPIM']
    )
    self.jblk = sym.Variable(name="JBLK", type=data_type, scope=routine)
    # The spec holds IR nodes: wrap the symbol in a declaration instead of
    # appending the bare expression symbol.
    routine.spec.append(ir.VariableDeclaration(symbols=as_tuple(self.jblk)))
346+
347+
def update_args(self, arg, region_map):
    """
    Return the replacement for call argument ``arg``, indexed by block.

    The replacement variable is the second item of ``region_map[arg.name]``.
    It is cloned with a full range (``:``) for each of its dimensions except
    the last one, which is set to ``self.jblk``.
    """
    block_var = region_map[arg.name][1]
    rank = len(block_var.dimensions)
    #rank = len(block_var.shape)
    full_ranges = (sym.RangeIndex((None, None)),) * (rank - 1)
    return block_var.clone(dimensions=full_ranges + (self.jblk,))
354+
355+
def create_compute_openmp(self, routine, region, region_name, region_map_temp, region_map_derived):
    """
    Build the OpenMP compute section for one parallel region.

    The returned tuple of IR nodes corresponds to::

        hook_compute 0
        call ylcpg_bnds%init(ydcpg_opts)
        !$omp parallel do private (jblk) firstprivate (ylcpg_bnds)
        do jblk = 1, ydcpg_opts%kgpblks
          call ylcpg_bnds%update(jblk)
          call callee(ydgeometry, ydmodel, ylcpg_bnds%kidia, ... (..., jblk))
        enddo
        hook_compute 1

    Every call statement of ``region`` other than ``DR_HOOK`` is copied into
    the loop body with its arguments rewritten:

    * arguments named in ``region_map_temp`` or ``region_map_derived`` are
      replaced through :meth:`update_args`;
    * arguments rooted at ``self.cpg_bnds`` are redirected to
      ``self.lcpg_bnds``;
    * all other arguments are passed through unchanged.

    Relies on ``self.cpg_opts``, ``self.cpg_bnds``, ``self.lcpg_bnds`` and
    ``self.jblk`` having been set beforehand (see :meth:`get_cpg`).

    Returns
    -------
    tuple
        ``(dr_hook_open, init_call, omp_pragma, block_loop, dr_hook_close)``
    """
    lcpg_bnds_name = self.lcpg_bnds.name

    init = ir.CallStatement(
        name=routine.resolve_typebound_var(f"{lcpg_bnds_name}%INIT"),
        arguments=(self.cpg_opts,)
    )

    #TODO : generate lst_private !!!!
    lst_private = "JBLK"
    # OpenMP clauses take a parenthesised list: PRIVATE (a, b) FIRSTPRIVATE (c)
    pragma = ir.Pragma(
        keyword="OMP",
        content=f"PARALLEL DO PRIVATE ({lst_private}) FIRSTPRIVATE ({lcpg_bnds_name})"
    )

    update = ir.CallStatement(
        name=routine.resolve_typebound_var(f"{lcpg_bnds_name}%UPDATE"),
        arguments=(self.jblk,)
    )

    #TODO : shouldn't this be the whole body of the region rather than only its calls?
    new_calls = []
    for call in FindNodes(ir.CallStatement).visit(region):
        if call.name != "DR_HOOK":
            new_arguments = []
            for arg in call.arguments:
                # Literals and compound expressions carry no name and are
                # passed through untouched.
                arg_name = getattr(arg, 'name', None)
                name_parts = getattr(arg, 'name_parts', None) or ()
                if arg_name in region_map_temp:
                    new_arguments.append(self.update_args(arg, region_map_temp))
                elif arg_name in region_map_derived:
                    new_arguments.append(self.update_args(arg, region_map_derived))
                elif name_parts and name_parts[0] == self.cpg_bnds.name:
                    if len(name_parts) == 1:
                        # The bounds object itself is passed: hand over the local copy.
                        new_arguments.append(self.lcpg_bnds)
                    else:
                        member = '%'.join(name_parts[1:])
                        new_arguments.append(
                            routine.resolve_typebound_var(f"{lcpg_bnds_name}%{member}")
                        )
                else:
                    new_arguments.append(arg)
            new_calls.append(call.clone(arguments=as_tuple(new_arguments)))

    loop_body = (update,) + tuple(new_calls)
    upper_bound = routine.resolve_typebound_var(f"{self.cpg_opts.name}%KGPBLKS")
    loop = ir.Loop(variable=self.jblk, bounds=sym.LoopRange((1, upper_bound)), body=loop_body)

    dr_hook_calls = self.create_dr_hook_calls(
        routine, f"{routine.name}:{region_name}:COMPUTE",
        sym.Variable(name='ZHOOK_HANDLE_COMPUTE', scope=routine)
    )
    return (dr_hook_calls[0], init, pragma, loop, dr_hook_calls[1])
409+
410+
def create_compute_openmpscc(self, routine, region, region_name, region_map_temp, region_map_derived):
    """Placeholder for the OpenMPSingleColumn compute section: does nothing and returns ``None``."""
    return None
412+
def create_compute_openaccscc(self, routine, region, region_name, region_map_temp, region_map_derived):
    """Placeholder for the OpenACCSingleColumn compute section: does nothing and returns ``None``."""
    return None

0 commit comments

Comments
 (0)