Skip to content

Commit 30bc7cb

Browse files
kaushikcfdinducer
authored andcommitted
adds an outlining example
1 parent 40fde67 commit 30bc7cb

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

examples/how_to_outline.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import dataclasses as dc
2+
3+
import numpy as np
4+
5+
import pytato as pt
6+
from pytools.obj_array import make_obj_array
7+
8+
from arraycontext import (
9+
Array, PytatoJAXArrayContext as BasePytatoJAXArrayContext,
10+
dataclass_array_container, with_container_arithmetic)
11+
12+
13+
Ncalls = 300
14+
15+
16+
class PytatoJAXArrayContext(BasePytatoJAXArrayContext):
17+
def transform_dag(self, dag):
18+
# Test 1: Test that the number of untransformed call sites are as
19+
# expected
20+
assert pt.analysis.get_num_call_sites(dag) == Ncalls
21+
22+
dag = pt.tag_all_calls_to_be_inlined(dag)
23+
print("[Pre-concatenation] Number of nodes =",
24+
pt.analysis.get_num_nodes(pt.inline_calls(dag)))
25+
dag = pt.concatenate_calls(
26+
dag,
27+
lambda cs: pt.tags.FunctionIdentifier("foo") in cs.call.function.tags
28+
)
29+
30+
# Test 2: Test that only one call-sites is left post concatentation
31+
assert pt.analysis.get_num_call_sites(dag) == 1
32+
33+
dag = pt.inline_calls(dag)
34+
print("[Post-concatenation] Number of nodes =",
35+
pt.analysis.get_num_nodes(dag))
36+
37+
return dag
38+
39+
40+
actx = PytatoJAXArrayContext()
41+
42+
43+
@with_container_arithmetic(
44+
bcast_obj_array=True,
45+
eq_comparison=False,
46+
rel_comparison=False,
47+
)
48+
@dataclass_array_container
49+
@dc.dataclass(frozen=True)
50+
class State:
51+
mass: Array
52+
vel: np.ndarray # np array of Arrays
53+
54+
55+
@actx.outline
56+
def foo(x1, x2):
57+
return (2*x1 + 3*x2 + x1**3 + x2**4
58+
+ actx.np.minimum(2*x1, 4*x2)
59+
+ actx.np.maximum(7*x1, 8*x2))
60+
61+
62+
rng = np.random.default_rng(0)
63+
Ndof = 10
64+
Ndim = 3
65+
66+
results = []
67+
68+
for _ in range(Ncalls):
69+
Nel = rng.integers(low=4, high=17)
70+
state1_np = State(
71+
mass=rng.random((Nel, Ndof)),
72+
vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]),
73+
)
74+
state2_np = State(
75+
mass=rng.random((Nel, Ndof)),
76+
vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]),
77+
)
78+
79+
state1 = actx.from_numpy(state1_np)
80+
state2 = actx.from_numpy(state2_np)
81+
results.append(foo(state1, state2))
82+
83+
actx.to_numpy(make_obj_array(results))

0 commit comments

Comments
 (0)