fix: attempt to improve test reliability. #1037

Open · wants to merge 1 commit into main

41 changes: 16 additions & 25 deletions tests/unit/test_util.py
@@ -259,7 +259,7 @@ def test_jit_test(self, args, kwargs, jit_kwargs) -> None:
         JIT compiling and executing the passed function, the latter is the time for
         dispatching the JIT compiled function.
         """
-        wait_time = 2
+        wait_time = 0.2
         trace_counter = Mock()
 
         x_in = jnp.ones(1000)
@@ -298,43 +298,34 @@ def test_speed_comparison_test(self) -> None:
         not cached, and that a for-looped variant of a mean function takes longer to
         compile than the in-built vectorised version.
         """
+        wait_time = 0.2
         trace_counter = Mock()
-        # Define a for-looped version of a mean computation, this will be very slow to
-        # compile.
 
         def _slow_mean(a):
             trace_counter()
-            num_points = a.shape[0]
-            total = 0
-            for i in range(num_points):
-                total += a[i]
-            return total / num_points
+            time.sleep(wait_time)
+            return jnp.mean(a)
 
-        random_vector = jr.normal(jr.key(2_024), shape=(100,))
+        a_in = jnp.ones(shape=(100,))
 
-        num_runs = 10
+        num_runs = 5
+        function_sequence = [
+            JITCompilableFunction(_slow_mean, fn_kwargs={"a": a_in}, name="slow_mean"),
+            JITCompilableFunction(jnp.mean, fn_kwargs={"a": a_in}, name="jnp_mean"),
+        ]
         summary_stats, result_dict = speed_comparison_test(
-            [
-                JITCompilableFunction(
-                    _slow_mean, fn_kwargs={"a": random_vector}, name="slow_mean"
-                ),
-                JITCompilableFunction(
-                    jnp.mean, fn_kwargs={"a": random_vector}, name="jnp_mean"
-                ),
-            ],
-            num_runs=num_runs,
-            log_results=False,
+            function_sequence, num_runs=num_runs, log_results=False
         )
 
         # Tracing should occur for each run of the function, not just once, thus
-        # `trace_counter` should be called num_runs times.
+        # `trace_counter` should be called `num_runs` times.
         assert trace_counter.call_count == num_runs
 
-        # Assert that indeed the mean compilation time of slow_mean is slower than
-        # jnp.mean.
+        # At trace time `time.sleep` will be called. Thus, we can be sure that
+        # `slow_mean_compilation_time` is lower bounded by `jnp_mean_compilation_time`.
         slow_mean_compilation_time = summary_stats[0][0][0]
-        fast_mean_compilation_time = summary_stats[1][0][0]
-        assert slow_mean_compilation_time > fast_mean_compilation_time > 0
+        jnp_mean_compilation_time = summary_stats[1][0][0]
+        assert slow_mean_compilation_time > jnp_mean_compilation_time > 0
 
         # Check result dictionary has the correct size
         assert len(result_dict["slow_mean"]) == num_runs
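The rewritten `_slow_mean` relies on `time.sleep` executing as a Python side effect while JAX traces the function, so it inflates measured compilation time deterministically without adding any work to the compiled computation. A minimal standalone sketch of that behaviour (the names `sleepy_mean`, `jitted_mean`, and the timing variables are illustrative, not from the repository):

# Sketch only: time.sleep runs while the function is traced, so the first jitted
# call is lower bounded by the sleep, while subsequent cached calls are not.
import time

import jax
import jax.numpy as jnp


def sleepy_mean(a):
    time.sleep(0.2)     # Python side effect: executes once, during tracing
    return jnp.mean(a)  # the only operation recorded in the traced program


jitted_mean = jax.jit(sleepy_mean)
x = jnp.ones(100)

start = time.perf_counter()
jitted_mean(x).block_until_ready()  # trace + compile + run: includes the 0.2 s sleep
first_call = time.perf_counter() - start

start = time.perf_counter()
jitted_mean(x).block_until_ready()  # cached: no retracing, hence no sleep
second_call = time.perf_counter() - start

assert first_call > 0.2 > second_call

This is the same lower-bound argument used by the new assertion: the slow function's measured compilation time must exceed that of plain `jnp.mean`.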
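For context on the removed variant: a Python for loop over array elements is unrolled while JAX traces the function, so the traced program grows with the input length, which made the old loop-based `_slow_mean` slow, and unpredictably slow, to compile. A rough sketch of that effect (the `looped_mean` helper below is illustrative, not the repository's code):

# Sketch only: compare traced-program sizes of a loop-based mean and jnp.mean.
import jax
import jax.numpy as jnp


def looped_mean(a):
    total = 0.0
    for i in range(a.shape[0]):  # unrolled at trace time: one index + add per element
        total += a[i]
    return total / a.shape[0]


x = jnp.ones(100)
looped_ops = len(jax.make_jaxpr(looped_mean)(x).jaxpr.eqns)
vectorised_ops = len(jax.make_jaxpr(jnp.mean)(x).jaxpr.eqns)
assert looped_ops > vectorised_ops  # roughly hundreds of primitives versus a handful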