Skip to content

Commit

Permalink
Jax: no @jit for 'internal' functions
Browse files Browse the repository at this point in the history
  • Loading branch information
paugier committed Aug 26, 2024
1 parent b8989ba commit 161a31b
Show file tree
Hide file tree
Showing 17 changed files with 9 additions and 67 deletions.
3 changes: 0 additions & 3 deletions data_tests/saved__backend__/jax/add_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,5 @@ def use_add(n=10000):
return tmp


# __protected__ @jit


def __transonic__():
return "0.7.1"
3 changes: 0 additions & 3 deletions data_tests/saved__backend__/jax/assign_func_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,5 @@ def func(x):
return x**2


# __protected__ @jit


def __transonic__():
return "0.7.1"
6 changes: 0 additions & 6 deletions data_tests/saved__backend__/jax/block_fluidsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ def rk2_step0(state_spect_n12, state_spect, tendencies_n, diss2, dt):
state_spect_n12[:] = (state_spect + dt / 2 * tendencies_n) * diss2


# __protected__ @jit


def arguments_blocks():
return {
"rk2_step0": [
Expand All @@ -27,8 +24,5 @@ def arguments_blocks():
}


# __protected__ @jit


def __transonic__():
return "0.7.1"
6 changes: 0 additions & 6 deletions data_tests/saved__backend__/jax/blocks_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,9 @@ def block0(a, b, n):
return result


# __protected__ @jit


def arguments_blocks():
return {"block0": ["a", "b", "n"]}


# __protected__ @jit


def __transonic__():
return "0.7.1"
6 changes: 0 additions & 6 deletions data_tests/saved__backend__/jax/boosted_class_use_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,9 @@ def __for_method__MyClass2__myfunc(self_attr0, self_attr1, arg):
return self_attr1 + self_attr0 + np.abs(arg) + func_import()


# __protected__ @jit


def __code_new_method__MyClass2__myfunc():
return "\n\ndef new_method(self, arg):\n return backend_func(self.attr0, self.attr1, arg)\n\n"


# __protected__ @jit


def __transonic__():
return "0.7.1"
3 changes: 0 additions & 3 deletions data_tests/saved__backend__/jax/boosted_func_use_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,5 @@ def func(a, b):
return (a * np.log(b)).max() + func_import()


# __protected__ @jit


def __transonic__():
return "0.7.1"
6 changes: 0 additions & 6 deletions data_tests/saved__backend__/jax/class_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,9 @@ def block1(a, b, n):
return result


# __protected__ @jit


def arguments_blocks():
return {"block0": ["a", "b", "n"], "block1": ["a", "b", "n"]}


# __protected__ @jit


def __transonic__():
return "0.7.1"
6 changes: 0 additions & 6 deletions data_tests/saved__backend__/jax/class_rec_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,9 @@ def __for_method__Myclass__func(self_attr, self_attr2, arg):
)


# __protected__ @jit


def __code_new_method__Myclass__func():
return "\n\ndef new_method(self, arg):\n return backend_func(self.attr, self.attr2, arg)\n\n"


# __protected__ @jit


def __transonic__():
return "0.7.1"
3 changes: 0 additions & 3 deletions data_tests/saved__backend__/jax/classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,5 @@ def func(a, b):
return (a * np.log(b)).max()


# __protected__ @jit


def __transonic__():
return "0.7.1"
3 changes: 0 additions & 3 deletions data_tests/saved__backend__/jax/default_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,5 @@ def func(a=1, b=None, c=1.0):
return a + c


# __protected__ @jit


def __transonic__():
return "0.7.1"
6 changes: 0 additions & 6 deletions data_tests/saved__backend__/jax/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,9 @@ def __for_method__Transmitter____call__(self_arr, self_freq, inp):
return (inp * np.exp(np.arange(len(inp)) * self_freq * 1j), self_arr)


# __protected__ @jit


def __code_new_method__Transmitter____call__():
return "\n\ndef new_method(self, inp):\n return backend_func(self.arr, self.freq, inp)\n\n"


# __protected__ @jit


def __transonic__():
return "0.7.1"
3 changes: 0 additions & 3 deletions data_tests/saved__backend__/jax/mixed_classic_type_hint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,5 @@ def func1(a, b):
return a * np.cos(b)


# __protected__ @jit


def __transonic__():
return "0.7.1"
3 changes: 0 additions & 3 deletions data_tests/saved__backend__/jax/no_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,5 @@ def func2():
return 1


# __protected__ @jit


def __transonic__():
return "0.7.1"
3 changes: 0 additions & 3 deletions data_tests/saved__backend__/jax/row_sum_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,5 @@ def row_sum_loops(arr, columns):
return res


# __protected__ @jit


def __transonic__():
return "0.7.1"
3 changes: 0 additions & 3 deletions data_tests/saved__backend__/jax/subpackages.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,5 @@ def test_sp_special(v, x):
return jv(v, x)


# __protected__ @jit


def __transonic__():
return "0.7.1"
3 changes: 0 additions & 3 deletions data_tests/saved__backend__/jax/type_hint_notemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,5 @@ def compute(a, b, c, d, e):
return tmp


# __protected__ @jit


def __transonic__():
return "0.7.1"
10 changes: 9 additions & 1 deletion src/transonic/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,15 @@ def add_jax_comments(code):
node.module = "jax.numpy"

# Add JIT decorator
if isinstance(node, gast.FunctionDef):
if (
isinstance(node, gast.FunctionDef)
and node.name
not in (
"arguments_blocks",
"__transonic__",
)
and not node.name.startswith("__code_new_method__")
):
new_body.append(CommentLine("# __protected__ @jit"))
new_body.append(node)

Expand Down

0 comments on commit 161a31b

Please sign in to comment.