11
11
import pytest
12
12
13
13
from loki .frontend import available_frontends , OMNI
14
- from loki import Sourcefile , FindNodes , CallStatement , fgen , Conditional
14
+ from loki import Sourcefile , FindNodes , CallStatement , fgen , Conditional , ProcedureItem
15
15
16
16
from transformations .parallel_routine_dispatch import ParallelRoutineDispatchTransformation
17
17
@@ -25,13 +25,14 @@ def fixture_here():
25
25
def test_parallel_routine_dispatch_dr_hook (here , frontend ):
26
26
27
27
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
28
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
28
29
routine = source ['dispatch_routine' ]
29
30
30
31
calls = FindNodes (CallStatement ).visit (routine .body )
31
32
assert len (calls ) == 3
32
33
33
34
transformation = ParallelRoutineDispatchTransformation ()
34
- transformation .apply (source ['dispatch_routine' ])
35
+ transformation .apply (source ['dispatch_routine' ], item = item )
35
36
36
37
calls = [call for call in FindNodes (CallStatement ).visit (routine .body ) if call .name .name == 'DR_HOOK' ]
37
38
assert len (calls ) == 8
@@ -40,10 +41,11 @@ def test_parallel_routine_dispatch_dr_hook(here, frontend):
40
41
def test_parallel_routine_dispatch_decl_local_arrays (here , frontend ):
41
42
42
43
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
44
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
43
45
routine = source ['dispatch_routine' ]
44
46
45
47
transformation = ParallelRoutineDispatchTransformation ()
46
- transformation .apply (source ['dispatch_routine' ])
48
+ transformation .apply (source ['dispatch_routine' ], item = item )
47
49
var_lst = ["YL_ZRDG_CVGQ" , "ZRDG_CVGQ" , "YL_ZRDG_MU0LU" , "ZRDG_MU0LU" , "YL_ZRDG_MU0M" , "ZRDG_MU0M" , "YL_ZRDG_MU0N" , "ZRDG_MU0N" , "YL_ZRDG_MU0" , "ZRDG_MU0" ]
48
50
dcls = [dcl for dcl in routine .declarations if dcl .symbols [0 ].name in var_lst ]
49
51
str_dcls = ""
@@ -65,10 +67,11 @@ def test_parallel_routine_dispatch_decl_local_arrays(here, frontend):
65
67
def test_parallel_routine_dispatch_decl_field_create_delete (here , frontend ):
66
68
67
69
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
70
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
68
71
routine = source ['dispatch_routine' ]
69
72
70
73
transformation = ParallelRoutineDispatchTransformation ()
71
- transformation .apply (source ['dispatch_routine' ])
74
+ transformation .apply (source ['dispatch_routine' ], item = item )
72
75
73
76
var_lst = ["YL_ZRDG_CVGQ" , "ZRDG_CVGQ" , "YL_ZRDG_MU0LU" , "ZRDG_MU0LU" , "YL_ZRDG_MU0M" , "ZRDG_MU0M" , "YL_ZRDG_MU0N" , "ZRDG_MU0N" , "YL_ZRDG_MU0" , "ZRDG_MU0" ]
74
77
field_create = ["CALL FIELD_NEW(YL_ZRDG_CVGQ, UBOUNDS=(/ YDCPG_OPTS%KLON, YDCPG_OPTS%KFLEVG, YDCPG_OPTS%KGPBLKS /), LBOUNDS=(/ 0, 1 /), &\n & PERSISTENT=.true.)" ,
@@ -105,10 +108,11 @@ def test_parallel_routine_dispatch_decl_field_create_delete(here, frontend):
105
108
def test_parallel_routine_dispatch_derived_dcl (here , frontend ):
106
109
107
110
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
111
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
108
112
routine = source ['dispatch_routine' ]
109
113
110
114
transformation = ParallelRoutineDispatchTransformation ()
111
- transformation .apply (source ['dispatch_routine' ])
115
+ transformation .apply (source ['dispatch_routine' ], item = item )
112
116
113
117
dcls = [fgen (dcl ) for dcl in routine .spec .body [- 13 :- 1 ]]
114
118
@@ -132,10 +136,11 @@ def test_parallel_routine_dispatch_derived_dcl(here, frontend):
132
136
def test_parallel_routine_dispatch_derived_var (here , frontend ):
133
137
134
138
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
139
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
135
140
routine = source ['dispatch_routine' ]
136
141
137
142
transformation = ParallelRoutineDispatchTransformation ()
138
- transformation .apply (source ['dispatch_routine' ])
143
+ transformation .apply (source ['dispatch_routine' ], item = item )
139
144
140
145
141
146
test_map = {
@@ -153,8 +158,9 @@ def test_parallel_routine_dispatch_derived_var(here, frontend):
153
158
"YDCPG_DYN0%CTY%EVEL" : ["YDCPG_DYN0%CTY%F_EVEL" , "Z_YDCPG_DYN0_CTY_EVEL" ],
154
159
"YDMF_PHYS_SURF%GSD_VF%PZ0F" : ["YDMF_PHYS_SURF%GSD_VF%F_Z0F" , "Z_YDMF_PHYS_SURF_GSD_VF_PZ0F" ]
155
160
}
156
- for var_name in transformation .routine_map_derived :
157
- value = transformation .routine_map_derived [var_name ]
161
+ routine_map_derived = item .trafo_data ['create_parallel' ]['map_routine' ]['routine_map_derived' ]
162
+ for var_name in routine_map_derived :
163
+ value = routine_map_derived [var_name ]
158
164
field_ptr = value [0 ]
159
165
ptr = value [1 ]
160
166
@@ -165,12 +171,13 @@ def test_parallel_routine_dispatch_derived_var(here, frontend):
165
171
def test_parallel_routine_dispatch_get_data (here , frontend ):
166
172
167
173
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
174
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
168
175
routine = source ['dispatch_routine' ]
169
176
170
177
transformation = ParallelRoutineDispatchTransformation ()
171
- transformation .apply (source ['dispatch_routine' ])
178
+ transformation .apply (source ['dispatch_routine' ], item = item )
172
179
173
- get_data = transformation . get_data
180
+ get_data = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ ' get_data' ]
174
181
175
182
test_get_data = {}
176
183
# test_get_data["OpenMP"] = """
@@ -276,7 +283,7 @@ def test_parallel_routine_dispatch_get_data(here, frontend):
276
283
### routine = source['dispatch_routine']
277
284
###
278
285
### transformation = ParallelRoutineDispatchTransformation()
279
- ### transformation.apply(source['dispatch_routine'])
286
+ ### transformation.apply(source['dispatch_routine'], item=item )
280
287
###
281
288
### get_data = transformation.get_data
282
289
###
@@ -294,12 +301,13 @@ def test_parallel_routine_dispatch_get_data(here, frontend):
294
301
def test_parallel_routine_dispatch_synchost (here , frontend ):
295
302
296
303
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
304
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
297
305
routine = source ['dispatch_routine' ]
298
306
299
307
transformation = ParallelRoutineDispatchTransformation ()
300
- transformation .apply (source ['dispatch_routine' ])
308
+ transformation .apply (source ['dispatch_routine' ], item = item )
301
309
302
- synchost = transformation . synchost [ 0 ]
310
+ synchost = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ 'synchost' ]
303
311
304
312
test_synchost = """IF (LSYNCHOST('DISPATCH_ROUTINE:CPPHINP')) THEN
305
313
IF (LHOOK) CALL DR_HOOK('DISPATCH_ROUTINE:CPPHINP:SYNCHOST', 0, ZHOOK_HANDLE_FIELD_API)
@@ -332,12 +340,13 @@ def test_parallel_routine_dispatch_synchost(here, frontend):
332
340
def test_parallel_routine_dispatch_nullify (here , frontend ):
333
341
334
342
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
343
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
335
344
routine = source ['dispatch_routine' ]
336
345
337
346
transformation = ParallelRoutineDispatchTransformation ()
338
- transformation .apply (source ['dispatch_routine' ])
347
+ transformation .apply (source ['dispatch_routine' ], item = item )
339
348
340
- nullify = transformation . nullify
349
+ nullify = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ ' nullify' ]
341
350
342
351
test_nullify = """
343
352
IF (LHOOK) CALL DR_HOOK('DISPATCH_ROUTINE:CPPHINP:NULLIFY', 0, ZHOOK_HANDLE_FIELD_API)
@@ -370,12 +379,13 @@ def test_parallel_routine_dispatch_nullify(here, frontend):
370
379
def test_parallel_routine_dispatch_compute_openmp (here , frontend ):
371
380
372
381
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
382
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
373
383
routine = source ['dispatch_routine' ]
374
384
375
385
transformation = ParallelRoutineDispatchTransformation ()
376
- transformation .apply (source ['dispatch_routine' ])
386
+ transformation .apply (source ['dispatch_routine' ], item = item )
377
387
378
- map_compute = transformation . compute
388
+ map_compute = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ ' compute' ]
379
389
compute_openmp = map_compute ['OpenMP' ]
380
390
381
391
test_compute = """
@@ -423,12 +433,13 @@ def test_parallel_routine_dispatch_compute_openmp(here, frontend):
423
433
def test_parallel_routine_dispatch_compute_openmpscc (here , frontend ):
424
434
425
435
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
436
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
426
437
routine = source ['dispatch_routine' ]
427
438
428
439
transformation = ParallelRoutineDispatchTransformation ()
429
- transformation .apply (source ['dispatch_routine' ])
440
+ transformation .apply (source ['dispatch_routine' ], item = item )
430
441
431
- map_compute = transformation . compute
442
+ map_compute = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ ' compute' ]
432
443
compute_openmpscc = map_compute ['OpenMPSingleColumn' ]
433
444
434
445
test_compute = """
@@ -490,12 +501,13 @@ def test_parallel_routine_dispatch_compute_openmpscc(here, frontend):
490
501
def test_parallel_routine_dispatch_compute_openaccscc (here , frontend ):
491
502
492
503
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
504
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
493
505
routine = source ['dispatch_routine' ]
494
506
495
507
transformation = ParallelRoutineDispatchTransformation ()
496
- transformation .apply (source ['dispatch_routine' ])
508
+ transformation .apply (source ['dispatch_routine' ], item = item )
497
509
498
- map_compute = transformation . compute
510
+ map_compute = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ ' compute' ]
499
511
compute_openaccscc = map_compute ['OpenACCSingleColumn' ]
500
512
501
513
test_compute = """
@@ -576,12 +588,13 @@ def test_parallel_routine_dispatch_compute_openaccscc(here, frontend):
576
588
def test_parallel_routine_dispatch_variables (here , frontend ):
577
589
578
590
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
591
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
579
592
routine = source ['dispatch_routine' ]
580
593
581
594
transformation = ParallelRoutineDispatchTransformation ()
582
- transformation .apply (source ['dispatch_routine' ])
595
+ transformation .apply (source ['dispatch_routine' ], item = item )
583
596
584
- variables = transformation . dcls
597
+ variables = item . trafo_data [ 'create_parallel' ][ 'map_routine' ][ ' dcls' ]
585
598
586
599
test_variables = '''TYPE(CPG_BNDS_TYPE), INTENT(IN) :: YLCPG_BNDS
587
600
TYPE(STACK) :: YLSTACK
@@ -597,12 +610,13 @@ def test_parallel_routine_dispatch_imports(here, frontend):
597
610
#TODO : add imports to _parallel routines
598
611
599
612
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
613
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
600
614
routine = source ['dispatch_routine' ]
601
615
602
616
transformation = ParallelRoutineDispatchTransformation ()
603
- transformation .apply (source ['dispatch_routine' ])
617
+ transformation .apply (source ['dispatch_routine' ], item = item )
604
618
605
- imports = transformation . imports
619
+ imports = item . trafo_data [ 'create_parallel' ][ 'map_routine' ][ ' imports' ]
606
620
607
621
test_imports = """
608
622
USE ACPY_MOD
@@ -621,11 +635,12 @@ def test_parallel_routine_dispatch_new_callee_imports(here, frontend):
621
635
#TODO : add imports to _parallel routines
622
636
623
637
source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
638
+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
624
639
routine = source ['dispatch_routine' ]
625
640
626
641
transformation = ParallelRoutineDispatchTransformation ()
627
- transformation .apply (source ['dispatch_routine' ])
642
+ transformation .apply (source ['dispatch_routine' ], item = item )
628
643
629
- imports = transformation . callee_imports
644
+ imports = item . trafo_data [ 'create_parallel' ][ 'map_routine' ][ ' callee_imports' ]
630
645
631
646
assert fgen (imports ) == '#include "cpphinp_openacc.intfb.h"'
0 commit comments