Skip to content

Commit

Permalink
Add data_tests/saved__backend__/jax
Browse files Browse the repository at this point in the history
  • Loading branch information
paugier committed Jun 1, 2024
1 parent 0fdc802 commit d01277d
Show file tree
Hide file tree
Showing 20 changed files with 319 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
const = 1
from __ext__MyClass2__exterior_import_boost_2 import func_import_2
import numpy as np


def func_import():
return const + func_import_2() + np.pi - np.pi
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
const = 1


def func_import_2():
return const
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
const = 1
from __ext__func__exterior_import_boost_2 import func_import_2
import numpy as np


def func_import():
return const + func_import_2() + np.pi - np.pi
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
const = 1


def func_import_2():
return const
19 changes: 19 additions & 0 deletions data_tests/saved__backend__/jax/add_inline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# __protected__ from jax import jit
# __protected__ @jit


def add(a, b):
return a + b


# __protected__ @jit


def use_add(n=10000):
tmp = 0
for _ in range(n):
tmp = add(tmp, 1)
return tmp


__transonic__ = "0.6.4"
9 changes: 9 additions & 0 deletions data_tests/saved__backend__/jax/assign_func_boost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# __protected__ from jax import jit
# __protected__ @jit


def func(x):
return x**2


__transonic__ = "0.6.4"
18 changes: 18 additions & 0 deletions data_tests/saved__backend__/jax/block_fluidsim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# __protected__ from jax import jit
# __protected__ @jit


def rk2_step0(state_spect_n12, state_spect, tendencies_n, diss2, dt):
# transonic block (
# complex128[][][] state_spect_n12, state_spect,
# tendencies_n;
# float64[][] diss2;
# float dt
# )
state_spect_n12[:] = (state_spect + dt / 2 * tendencies_n) * diss2


arguments_blocks = {
"rk2_step0": ["state_spect_n12", "state_spect", "tendencies_n", "diss2", "dt"]
}
__transonic__ = "0.6.4"
19 changes: 19 additions & 0 deletions data_tests/saved__backend__/jax/blocks_type_hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# __protected__ from jax import jit
# __protected__ @jit


def block0(a, b, n):
# transonic block (
# A a; A1 b;
# int n
# )
# transonic block (
# int[:] a, b;
# float n
# )
result = a**2 + b.mean() ** 3 + n
return result


arguments_blocks = {"block0": ["a", "b", "n"]}
__transonic__ = "0.6.4"
13 changes: 13 additions & 0 deletions data_tests/saved__backend__/jax/boosted_class_use_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# __protected__ from jax import jit
import jax.numpy as np
from __ext__MyClass2__exterior_import_boost import func_import

# __protected__ @jit


def __for_method__MyClass2__myfunc(self_attr0, self_attr1, arg):
return self_attr1 + self_attr0 + np.abs(arg) + func_import()


__code_new_method__MyClass2__myfunc = "\n\ndef new_method(self, arg):\n return backend_func(self.attr0, self.attr1, arg)\n\n"
__transonic__ = "0.6.4"
12 changes: 12 additions & 0 deletions data_tests/saved__backend__/jax/boosted_func_use_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# __protected__ from jax import jit
import jax.numpy as np
from __ext__func__exterior_import_boost import func_import

# __protected__ @jit


def func(a, b):
return (a * np.log(b)).max() + func_import()


__transonic__ = "0.6.4"
44 changes: 44 additions & 0 deletions data_tests/saved__backend__/jax/class_blocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# __protected__ from jax import jit
import jax.numpy as np

# __protected__ @jit


def block0(a, b, n):
# foo
# transonic block (
# float[][] a, b;
# int n
# ) bar
# foo
# transonic block (
# float[][][] a, b;
# int n
# )
# foobar
result = np.zeros_like(a)
for _ in range(n):
result += a**2 + b**3
return result


# __protected__ @jit


def block1(a, b, n):
# transonic block (
# float[][] a, b;
# int n
# )
# transonic block (
# float[][][] a, b;
# int n
# )
result = np.zeros_like(a)
for _ in range(n):
result += a**2 + b**3
return result


arguments_blocks = {"block0": ["a", "b", "n"], "block1": ["a", "b", "n"]}
__transonic__ = "0.6.4"
20 changes: 20 additions & 0 deletions data_tests/saved__backend__/jax/class_rec_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# __protected__ from jax import jit
# __protected__ @jit


def __for_method__Myclass__func(self_attr, self_attr2, arg):
if __for_method__Myclass__func(self_attr, self_attr2, arg - 1) < 1:
return 1
else:
a = __for_method__Myclass__func(
self_attr, self_attr2, arg - 1
) * __for_method__Myclass__func(self_attr, self_attr2, arg - 1)
return (
a
+ self_attr * self_attr2 * arg
+ __for_method__Myclass__func(self_attr, self_attr2, arg - 1)
)


__code_new_method__Myclass__func = "\n\ndef new_method(self, arg):\n return backend_func(self.attr, self.attr2, arg)\n\n"
__transonic__ = "0.6.4"
11 changes: 11 additions & 0 deletions data_tests/saved__backend__/jax/classic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# __protected__ from jax import jit
import jax.numpy as np

# __protected__ @jit


def func(a, b):
return (a * np.log(b)).max()


__transonic__ = "0.6.4"
10 changes: 10 additions & 0 deletions data_tests/saved__backend__/jax/default_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# __protected__ from jax import jit
# __protected__ @jit


def func(a=1, b=None, c=1.0):
print(b)
return a + c


__transonic__ = "0.6.4"
13 changes: 13 additions & 0 deletions data_tests/saved__backend__/jax/methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# __protected__ from jax import jit
import jax.numpy as np

# __protected__ @jit


def __for_method__Transmitter____call__(self_arr, self_freq, inp):
"""My docstring"""
return (inp * np.exp(np.arange(len(inp)) * self_freq * 1j), self_arr)


__code_new_method__Transmitter____call__ = "\n\ndef new_method(self, inp):\n return backend_func(self.arr, self.freq, inp)\n\n"
__transonic__ = "0.6.4"
18 changes: 18 additions & 0 deletions data_tests/saved__backend__/jax/mixed_classic_type_hint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# __protected__ from jax import jit
import jax.numpy as np

# __protected__ @jit


def func(a, b):
return (a * np.log(b)).max()


# __protected__ @jit


def func1(a, b):
return a * np.cos(b)


__transonic__ = "0.6.4"
16 changes: 16 additions & 0 deletions data_tests/saved__backend__/jax/no_arg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# __protected__ from jax import jit
# __protected__ @jit


def func():
return 1


# __protected__ @jit


def func2():
return 1


__transonic__ = "0.6.4"
27 changes: 27 additions & 0 deletions data_tests/saved__backend__/jax/row_sum_boost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# __protected__ from jax import jit
import jax.numpy as np

# __protected__ @jit


def row_sum(arr, columns):
return arr.T[columns].sum(0)


# __protected__ @jit


def row_sum_loops(arr, columns):
# locals type annotations are used only for Cython
# arr.dtype not supported for memoryview
dtype = type(arr[0, 0])
res = np.empty(arr.shape[0], dtype=dtype)
for i in range(arr.shape[0]):
sum_ = dtype(0)
for j in range(columns.shape[0]):
sum_ += arr[i, columns[j]]
res[i] = sum_
return res


__transonic__ = "0.6.4"
33 changes: 33 additions & 0 deletions data_tests/saved__backend__/jax/subpackages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# __protected__ from jax import jit
from numpy.fft import rfft
from numpy.random import randn
from numpy.linalg import matrix_power
from scipy.special import jv

# __protected__ @jit


def test_np_fft(u):
u_fft = rfft(u)
return u_fft


# __protected__ @jit


def test_np_linalg_random(u):
(nx, ny) = u.shape
u[:] = randn(nx, ny)
u2 = u.T * u
u4 = matrix_power(u2, 2)
return u4


# __protected__ @jit


def test_sp_special(v, x):
return jv(v, x)


__transonic__ = "0.6.4"
13 changes: 13 additions & 0 deletions data_tests/saved__backend__/jax/type_hint_notemplate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# __protected__ from jax import jit
# __protected__ @jit


def compute(a, b, c, d, e):
print(e)
tmp = a + b
if 1 and 2:
tmp *= 2
return tmp


__transonic__ = "0.6.4"

0 comments on commit d01277d

Please sign in to comment.