|
| 1 | +import dataclasses as dc |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +import pytato as pt |
| 6 | +from pytools.obj_array import make_obj_array |
| 7 | + |
| 8 | +from arraycontext import ( |
| 9 | + Array, PytatoJAXArrayContext as BasePytatoJAXArrayContext, |
| 10 | + dataclass_array_container, with_container_arithmetic) |
| 11 | + |
| 12 | + |
| 13 | +Ncalls = 300 |
| 14 | + |
| 15 | + |
| 16 | +class PytatoJAXArrayContext(BasePytatoJAXArrayContext): |
| 17 | + def transform_dag(self, dag): |
| 18 | + # Test 1: Test that the number of untransformed call sites are as |
| 19 | + # expected |
| 20 | + assert pt.analysis.get_num_call_sites(dag) == Ncalls |
| 21 | + |
| 22 | + dag = pt.tag_all_calls_to_be_inlined(dag) |
| 23 | + print("[Pre-concatenation] Number of nodes =", |
| 24 | + pt.analysis.get_num_nodes(pt.inline_calls(dag))) |
| 25 | + dag = pt.concatenate_calls( |
| 26 | + dag, |
| 27 | + lambda cs: pt.tags.FunctionIdentifier("foo") in cs.call.function.tags |
| 28 | + ) |
| 29 | + |
| 30 | + # Test 2: Test that only one call-sites is left post concatentation |
| 31 | + assert pt.analysis.get_num_call_sites(dag) == 1 |
| 32 | + |
| 33 | + dag = pt.inline_calls(dag) |
| 34 | + print("[Post-concatenation] Number of nodes =", |
| 35 | + pt.analysis.get_num_nodes(dag)) |
| 36 | + |
| 37 | + return dag |
| 38 | + |
| 39 | + |
| 40 | +actx = PytatoJAXArrayContext() |
| 41 | + |
| 42 | + |
| 43 | +@with_container_arithmetic( |
| 44 | + bcast_obj_array=True, |
| 45 | + eq_comparison=False, |
| 46 | + rel_comparison=False, |
| 47 | +) |
| 48 | +@dataclass_array_container |
| 49 | +@dc.dataclass(frozen=True) |
| 50 | +class State: |
| 51 | + mass: Array |
| 52 | + vel: np.ndarray # np array of Arrays |
| 53 | + |
| 54 | + |
| 55 | +@actx.outline |
| 56 | +def foo(x1, x2): |
| 57 | + return (2*x1 + 3*x2 + x1**3 + x2**4 |
| 58 | + + actx.np.minimum(2*x1, 4*x2) |
| 59 | + + actx.np.maximum(7*x1, 8*x2)) |
| 60 | + |
| 61 | + |
| 62 | +rng = np.random.default_rng(0) |
| 63 | +Ndof = 10 |
| 64 | +Ndim = 3 |
| 65 | + |
| 66 | +results = [] |
| 67 | + |
| 68 | +for _ in range(Ncalls): |
| 69 | + Nel = rng.integers(low=4, high=17) |
| 70 | + state1_np = State( |
| 71 | + mass=rng.random((Nel, Ndof)), |
| 72 | + vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), |
| 73 | + ) |
| 74 | + state2_np = State( |
| 75 | + mass=rng.random((Nel, Ndof)), |
| 76 | + vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), |
| 77 | + ) |
| 78 | + |
| 79 | + state1 = actx.from_numpy(state1_np) |
| 80 | + state2 = actx.from_numpy(state2_np) |
| 81 | + results.append(foo(state1, state2)) |
| 82 | + |
| 83 | +actx.to_numpy(make_obj_array(results)) |
0 commit comments