|
5 | 5 | # granted to it by virtue of its status as an intergovernmental organisation
|
6 | 6 | # nor does it submit to any jurisdiction.
|
7 | 7 |
|
| 8 | +from loki.expression import symbols as sym |
| 9 | +from loki.ir import ( |
| 10 | + is_loki_pragma, get_pragma_parameters, pragma_regions_attached, |
| 11 | + FindNodes, nodes as ir |
| 12 | +) |
8 | 13 | from loki.transform import Transformation
|
9 | 14 |
|
10 | 15 | __all__ = ['ParallelRoutineDispatchTransformation']
|
11 | 16 |
|
12 | 17 |
|
13 | 18 | class ParallelRoutineDispatchTransformation(Transformation):
|
14 | 19 |
|
15 |
| - def __init__(self): |
16 |
| - self.dummy_return_value = [] |
17 |
| - |
18 | 20 | def transform_subroutine(self, routine, **kwargs):
|
19 |
| - self.dummy_return_value += [routine.name.lower()] |
| 21 | + with pragma_regions_attached(routine): |
| 22 | + for region in FindNodes(ir.PragmaRegion).visit(routine.body): |
| 23 | + if is_loki_pragma(region.pragma): |
| 24 | + self.process_parallel_region(routine, region) |
| 25 | + |
| 26 | + def process_parallel_region(self, routine, region): |
| 27 | + pragma_content = region.pragma.content.split(maxsplit=1) |
| 28 | + pragma_content = [entry.split('=', maxsplit=1) for entry in pragma_content[1].split(',')] |
| 29 | + pragma_attrs = { |
| 30 | + entry[0].lower(): entry[1] if len(entry) == 2 else None |
| 31 | + for entry in pragma_content |
| 32 | + } |
| 33 | + if 'parallel' not in pragma_attrs: |
| 34 | + return |
| 35 | + |
| 36 | + dr_hook_calls = self.create_dr_hook_calls( |
| 37 | + routine, pragma_attrs['name'], |
| 38 | + sym.Variable(name='ZHOOK_HANDLE_FIELD_API', scope=routine) |
| 39 | + ) |
| 40 | + |
| 41 | + region.prepend(dr_hook_calls[0]) |
| 42 | + region.append(dr_hook_calls[1]) |
| 43 | + |
| 44 | + @staticmethod |
| 45 | + def create_dr_hook_calls(scope, cdname, pkey): |
| 46 | + dr_hook_calls = [] |
| 47 | + for kswitch in (0, 1): |
| 48 | + call_stmt = ir.CallStatement( |
| 49 | + name=sym.Variable(name='DR_HOOK', scope=scope), |
| 50 | + arguments=(sym.StringLiteral(cdname), sym.IntLiteral(kswitch), pkey) |
| 51 | + ) |
| 52 | + dr_hook_calls += [ |
| 53 | + ir.Conditional( |
| 54 | + condition=sym.Variable(name='LHOOK', scope=scope), |
| 55 | + inline=True, body=(call_stmt,) |
| 56 | + ) |
| 57 | + ] |
| 58 | + return dr_hook_calls |
0 commit comments