use segment_sum for mv_multiply #1

Draft
wants to merge 2 commits into
base: master

RobinKa
Owner

@RobinKa RobinKa commented Nov 6, 2021

  • Previously used a loop with an add-at-index per term, which gets unrolled under JIT; now uses segment_sum to sum the terms that share an output index
  • ~10x faster JIT on CPU, ~6x faster JIT on GPU
  • Up to ~100x slower runtime on CPU, ~5x faster runtime on GPU

It may be worth adding a flag to choose between the two implementations. The segment_sum version is very useful for large algebras, where JIT takes very long because it has to analyze the unrolled loop. It could also become the default on GPU.
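
One possible shape for such a flag, as a sketch only: `mv_multiply_sketch`, `use_segment_sum` and `default_use_segment_sum` are hypothetical names, the product table is a toy one (complex multiplication), and inputs are assumed to be 2-D with the blade axis first. The flag is marked static so each setting compiles to its own program.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical toy product table (complex multiplication).
A_IDX, B_IDX, OUT_IDX = (0, 0, 1, 1), (0, 1, 0, 1), (0, 1, 1, 0)
SIGN = (1.0, 1.0, 1.0, -1.0)
NUM_OUT = 2


@partial(jax.jit, static_argnames=("use_segment_sum",))
def mv_multiply_sketch(a, b, use_segment_sum=False):
    # a, b: [num_blades, batch] (assumed layout).
    sign = jnp.asarray(SIGN, dtype=a.dtype)[:, None]
    terms = sign * a[jnp.asarray(A_IDX)] * b[jnp.asarray(B_IDX)]
    if use_segment_sum:
        # Single reduction over terms that share an output index.
        return jax.ops.segment_sum(terms, jnp.asarray(OUT_IDX), num_segments=NUM_OUT)
    # Unrolled path: one add-at-index per term.
    out = jnp.zeros((NUM_OUT, a.shape[1]), dtype=a.dtype)
    for k, out_k in enumerate(OUT_IDX):
        out = out.at[out_k].add(terms[k])
    return out


def default_use_segment_sum():
    # Assumed policy, following the suggestion above: segment_sum on GPU only.
    return jax.default_backend() == "gpu"


a = jnp.ones((2, 3), dtype=jnp.float32)
b = jnp.ones((2, 3), dtype=jnp.float32)
out = mv_multiply_sketch(a, b, use_segment_sum=default_use_segment_sum())
```

Callers who care more about compile time than runtime (large algebras on CPU) could still pass `use_segment_sum=True` explicitly.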

CPU results show that the segment_sum runtime depends strongly on batch size:

a_val, a_ind = jnp.array(jnp.ones([5, 10]), dtype=jnp.float32), tuple((i,) for i in range(5))
b_val, b_ind = jnp.array(jnp.ones([5, 10]), dtype=jnp.float32), tuple((i,i+1) for i in range(5))

new:
Wall time: 94 ms
10.8 µs ± 301 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

old:
Wall time: 1.04 s
10.9 µs ± 152 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

---
a_val, a_ind = jnp.array(jnp.ones([5, 100]), dtype=jnp.float32), tuple((i,) for i in range(5))
b_val, b_ind = jnp.array(jnp.ones([5, 100]), dtype=jnp.float32), tuple((i,i+1) for i in range(5))

new:
Wall time: 227 ms
48.6 µs ± 640 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

old:
Wall time: 676 ms
11.6 µs ± 464 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
---
a_val, a_ind = jnp.array(jnp.ones([10, 100]), dtype=jnp.float32), tuple((i,) for i in range(5))
b_val, b_ind = jnp.array(jnp.ones([10, 100]), dtype=jnp.float32), tuple((i,i+1) for i in range(5))

new:
Wall time: 261 ms
49.1 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

old:
Wall time: 687 ms
11.6 µs ± 196 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
---
a_val, a_ind = jnp.array(jnp.ones([10, 1000]), dtype=jnp.float32), tuple((i,) for i in range(5))
b_val, b_ind = jnp.array(jnp.ones([10, 1000]), dtype=jnp.float32), tuple((i,i+1) for i in range(5))

new:
Wall time: 256 ms
1.19 ms ± 69.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

old:
Wall time: 558 ms
16.9 µs ± 234 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
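
To reproduce measurements of this kind outside a notebook, something like the following sketch could be used. It assumes that "Wall time" above is the first call (tracing plus compilation) and that the per-loop figure is the steady-state runtime; `time_first_and_steady` is a hypothetical helper, and the jitted elementwise product is only a stand-in for the real multiply function.

```python
import time

import jax
import jax.numpy as jnp


def time_first_and_steady(fn, *args, repeats=100):
    """Return (first-call seconds, mean steady-state seconds per call)."""
    # First call: includes tracing and compilation for a fresh jitted fn.
    start = time.perf_counter()
    fn(*args).block_until_ready()
    first_call = time.perf_counter() - start

    # Later calls hit the compilation cache; block at the end because
    # JAX dispatches work asynchronously.
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    out.block_until_ready()
    steady = (time.perf_counter() - start) / repeats
    return first_call, steady


# Stand-in for the function under test.
stand_in = jax.jit(lambda a, b: a * b)
a_val = jnp.ones([5, 10], dtype=jnp.float32)
b_val = jnp.ones([5, 10], dtype=jnp.float32)
first_call, steady = time_first_and_steady(stand_in, a_val, b_val)
print(f"first call: {first_call * 1e3:.2f} ms, steady: {steady * 1e6:.2f} µs")
```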

@RobinKa RobinKa mentioned this pull request Dec 28, 2021