Skip to content

Commit e77e88d

Browse files
committed
Add item to transformation, create map_routine and map_region
1 parent 74fed5d commit e77e88d

File tree

2 files changed

+185
-122
lines changed

2 files changed

+185
-122
lines changed

transformations/tests/test_parallel_routine_dispatch.py

+42-27
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import pytest
1212

1313
from loki.frontend import available_frontends, OMNI
14-
from loki import Sourcefile, FindNodes, CallStatement, fgen, Conditional
14+
from loki import Sourcefile, FindNodes, CallStatement, fgen, Conditional, ProcedureItem
1515

1616
from transformations.parallel_routine_dispatch import ParallelRoutineDispatchTransformation
1717

@@ -25,13 +25,14 @@ def fixture_here():
2525
def test_parallel_routine_dispatch_dr_hook(here, frontend):
2626

2727
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
28+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
2829
routine = source['dispatch_routine']
2930

3031
calls = FindNodes(CallStatement).visit(routine.body)
3132
assert len(calls) == 3
3233

3334
transformation = ParallelRoutineDispatchTransformation()
34-
transformation.apply(source['dispatch_routine'])
35+
transformation.apply(source['dispatch_routine'], item=item)
3536

3637
calls = [call for call in FindNodes(CallStatement).visit(routine.body) if call.name.name=='DR_HOOK']
3738
assert len(calls) == 8
@@ -40,10 +41,11 @@ def test_parallel_routine_dispatch_dr_hook(here, frontend):
4041
def test_parallel_routine_dispatch_decl_local_arrays(here, frontend):
4142

4243
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
44+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
4345
routine = source['dispatch_routine']
4446

4547
transformation = ParallelRoutineDispatchTransformation()
46-
transformation.apply(source['dispatch_routine'])
48+
transformation.apply(source['dispatch_routine'], item=item)
4749
var_lst=["YL_ZRDG_CVGQ", "ZRDG_CVGQ", "YL_ZRDG_MU0LU", "ZRDG_MU0LU", "YL_ZRDG_MU0M", "ZRDG_MU0M", "YL_ZRDG_MU0N", "ZRDG_MU0N", "YL_ZRDG_MU0", "ZRDG_MU0"]
4850
dcls = [dcl for dcl in routine.declarations if dcl.symbols[0].name in var_lst]
4951
str_dcls = ""
@@ -65,10 +67,11 @@ def test_parallel_routine_dispatch_decl_local_arrays(here, frontend):
6567
def test_parallel_routine_dispatch_decl_field_create_delete(here, frontend):
6668

6769
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
70+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
6871
routine = source['dispatch_routine']
6972

7073
transformation = ParallelRoutineDispatchTransformation()
71-
transformation.apply(source['dispatch_routine'])
74+
transformation.apply(source['dispatch_routine'], item=item)
7275

7376
var_lst = ["YL_ZRDG_CVGQ", "ZRDG_CVGQ", "YL_ZRDG_MU0LU", "ZRDG_MU0LU", "YL_ZRDG_MU0M", "ZRDG_MU0M", "YL_ZRDG_MU0N", "ZRDG_MU0N", "YL_ZRDG_MU0", "ZRDG_MU0"]
7477
field_create = ["CALL FIELD_NEW(YL_ZRDG_CVGQ, UBOUNDS=(/ YDCPG_OPTS%KLON, YDCPG_OPTS%KFLEVG, YDCPG_OPTS%KGPBLKS /), LBOUNDS=(/ 0, 1 /), &\n& PERSISTENT=.true.)",
@@ -105,10 +108,11 @@ def test_parallel_routine_dispatch_decl_field_create_delete(here, frontend):
105108
def test_parallel_routine_dispatch_derived_dcl(here, frontend):
106109

107110
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
111+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
108112
routine = source['dispatch_routine']
109113

110114
transformation = ParallelRoutineDispatchTransformation()
111-
transformation.apply(source['dispatch_routine'])
115+
transformation.apply(source['dispatch_routine'], item=item)
112116

113117
dcls = [fgen(dcl) for dcl in routine.spec.body[-13:-1]]
114118

@@ -132,10 +136,11 @@ def test_parallel_routine_dispatch_derived_dcl(here, frontend):
132136
def test_parallel_routine_dispatch_derived_var(here, frontend):
133137

134138
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
139+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
135140
routine = source['dispatch_routine']
136141

137142
transformation = ParallelRoutineDispatchTransformation()
138-
transformation.apply(source['dispatch_routine'])
143+
transformation.apply(source['dispatch_routine'], item=item)
139144

140145

141146
test_map = {
@@ -153,8 +158,9 @@ def test_parallel_routine_dispatch_derived_var(here, frontend):
153158
"YDCPG_DYN0%CTY%EVEL" : ["YDCPG_DYN0%CTY%F_EVEL", "Z_YDCPG_DYN0_CTY_EVEL"],
154159
"YDMF_PHYS_SURF%GSD_VF%PZ0F" : ["YDMF_PHYS_SURF%GSD_VF%F_Z0F", "Z_YDMF_PHYS_SURF_GSD_VF_PZ0F"]
155160
}
156-
for var_name in transformation.routine_map_derived:
157-
value = transformation.routine_map_derived[var_name]
161+
routine_map_derived = item.trafo_data['create_parallel']['map_routine']['routine_map_derived']
162+
for var_name in routine_map_derived:
163+
value = routine_map_derived[var_name]
158164
field_ptr = value[0]
159165
ptr = value[1]
160166

@@ -165,12 +171,13 @@ def test_parallel_routine_dispatch_derived_var(here, frontend):
165171
def test_parallel_routine_dispatch_get_data(here, frontend):
166172

167173
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
174+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
168175
routine = source['dispatch_routine']
169176

170177
transformation = ParallelRoutineDispatchTransformation()
171-
transformation.apply(source['dispatch_routine'])
178+
transformation.apply(source['dispatch_routine'], item=item)
172179

173-
get_data = transformation.get_data
180+
get_data = item.trafo_data['create_parallel']['map_region']['get_data']
174181

175182
test_get_data = {}
176183
# test_get_data["OpenMP"] = """
@@ -276,7 +283,7 @@ def test_parallel_routine_dispatch_get_data(here, frontend):
276283
### routine = source['dispatch_routine']
277284
###
278285
### transformation = ParallelRoutineDispatchTransformation()
279-
### transformation.apply(source['dispatch_routine'])
286+
### transformation.apply(source['dispatch_routine'], item=item)
280287
###
281288
### get_data = transformation.get_data
282289
###
@@ -294,12 +301,13 @@ def test_parallel_routine_dispatch_get_data(here, frontend):
294301
def test_parallel_routine_dispatch_synchost(here, frontend):
295302

296303
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
304+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
297305
routine = source['dispatch_routine']
298306

299307
transformation = ParallelRoutineDispatchTransformation()
300-
transformation.apply(source['dispatch_routine'])
308+
transformation.apply(source['dispatch_routine'], item=item)
301309

302-
synchost = transformation.synchost[0]
310+
synchost = item.trafo_data['create_parallel']['map_region']['synchost']
303311

304312
test_synchost = """IF (LSYNCHOST('DISPATCH_ROUTINE:CPPHINP')) THEN
305313
IF (LHOOK) CALL DR_HOOK('DISPATCH_ROUTINE:CPPHINP:SYNCHOST', 0, ZHOOK_HANDLE_FIELD_API)
@@ -332,12 +340,13 @@ def test_parallel_routine_dispatch_synchost(here, frontend):
332340
def test_parallel_routine_dispatch_nullify(here, frontend):
333341

334342
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
343+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
335344
routine = source['dispatch_routine']
336345

337346
transformation = ParallelRoutineDispatchTransformation()
338-
transformation.apply(source['dispatch_routine'])
347+
transformation.apply(source['dispatch_routine'], item=item)
339348

340-
nullify = transformation.nullify
349+
nullify = item.trafo_data['create_parallel']['map_region']['nullify']
341350

342351
test_nullify = """
343352
IF (LHOOK) CALL DR_HOOK('DISPATCH_ROUTINE:CPPHINP:NULLIFY', 0, ZHOOK_HANDLE_FIELD_API)
@@ -370,12 +379,13 @@ def test_parallel_routine_dispatch_nullify(here, frontend):
370379
def test_parallel_routine_dispatch_compute_openmp(here, frontend):
371380

372381
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
382+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
373383
routine = source['dispatch_routine']
374384

375385
transformation = ParallelRoutineDispatchTransformation()
376-
transformation.apply(source['dispatch_routine'])
386+
transformation.apply(source['dispatch_routine'], item=item)
377387

378-
map_compute = transformation.compute
388+
map_compute = item.trafo_data['create_parallel']['map_region']['compute']
379389
compute_openmp = map_compute['OpenMP']
380390

381391
test_compute= """
@@ -423,12 +433,13 @@ def test_parallel_routine_dispatch_compute_openmp(here, frontend):
423433
def test_parallel_routine_dispatch_compute_openmpscc(here, frontend):
424434

425435
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
436+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
426437
routine = source['dispatch_routine']
427438

428439
transformation = ParallelRoutineDispatchTransformation()
429-
transformation.apply(source['dispatch_routine'])
440+
transformation.apply(source['dispatch_routine'], item=item)
430441

431-
map_compute = transformation.compute
442+
map_compute = item.trafo_data['create_parallel']['map_region']['compute']
432443
compute_openmpscc = map_compute['OpenMPSingleColumn']
433444

434445
test_compute= """
@@ -490,12 +501,13 @@ def test_parallel_routine_dispatch_compute_openmpscc(here, frontend):
490501
def test_parallel_routine_dispatch_compute_openaccscc(here, frontend):
491502

492503
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
504+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
493505
routine = source['dispatch_routine']
494506

495507
transformation = ParallelRoutineDispatchTransformation()
496-
transformation.apply(source['dispatch_routine'])
508+
transformation.apply(source['dispatch_routine'], item=item)
497509

498-
map_compute = transformation.compute
510+
map_compute = item.trafo_data['create_parallel']['map_region']['compute']
499511
compute_openaccscc = map_compute['OpenACCSingleColumn']
500512

501513
test_compute = """
@@ -576,12 +588,13 @@ def test_parallel_routine_dispatch_compute_openaccscc(here, frontend):
576588
def test_parallel_routine_dispatch_variables(here, frontend):
577589

578590
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
591+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
579592
routine = source['dispatch_routine']
580593

581594
transformation = ParallelRoutineDispatchTransformation()
582-
transformation.apply(source['dispatch_routine'])
595+
transformation.apply(source['dispatch_routine'], item=item)
583596

584-
variables = transformation.dcls
597+
variables = item.trafo_data['create_parallel']['map_routine']['dcls']
585598

586599
test_variables = '''TYPE(CPG_BNDS_TYPE), INTENT(IN) :: YLCPG_BNDS
587600
TYPE(STACK) :: YLSTACK
@@ -597,12 +610,13 @@ def test_parallel_routine_dispatch_imports(here, frontend):
597610
#TODO : add imports to _parallel routines
598611

599612
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
613+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
600614
routine = source['dispatch_routine']
601615

602616
transformation = ParallelRoutineDispatchTransformation()
603-
transformation.apply(source['dispatch_routine'])
617+
transformation.apply(source['dispatch_routine'], item=item)
604618

605-
imports = transformation.imports
619+
imports = item.trafo_data['create_parallel']['map_routine']['imports']
606620

607621
test_imports = """
608622
USE ACPY_MOD
@@ -621,11 +635,12 @@ def test_parallel_routine_dispatch_new_callee_imports(here, frontend):
621635
#TODO : add imports to _parallel routines
622636

623637
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
638+
item = ProcedureItem(name='parallel_routine_dispatch', source=source)
624639
routine = source['dispatch_routine']
625640

626641
transformation = ParallelRoutineDispatchTransformation()
627-
transformation.apply(source['dispatch_routine'])
642+
transformation.apply(source['dispatch_routine'], item=item)
628643

629-
imports = transformation.callee_imports
644+
imports = item.trafo_data['create_parallel']['map_routine']['callee_imports']
630645

631646
assert fgen(imports) == '#include "cpphinp_openacc.intfb.h"'

0 commit comments

Comments
 (0)