Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify BlockArray implementation #259

Merged
merged 37 commits into from
Apr 22, 2022
Merged

Conversation

Michael-T-McCann
Copy link
Contributor

@Michael-T-McCann Michael-T-McCann commented Mar 22, 2022

This is a work in progress PR to change the underlying implementation of BlockArray from a flattened DeviceArray to a list of DeviceArrays. The goal is to simplify the BlockArray implementation, reduce its coupling to jax internals, and provide additional functionality (mixed datatypes, blocks on different GPUs).

Closes #179. Also touches #237 #238 #159 #239.

Timing examples (best of 3 runs, total time, scripts edited to remove input, MacBook Pro, CPU):

time python examples/scripts/denoise_tv_iso_pgm.py > /dev/null 2>&1
old: 2.936, new 7.185,
w/ additional @jit.jit, old: 2.936, new: 2.923

time python examples/scripts/sparsecode_poisson_pgm.py > /dev/null 2>&1
old: 12.964, new 12.663

Timing simple ops

  • array with 10,000 blocks, %timeit -n 3 -r 3 x = snp.ones(10000*((2, 2),))
    • new
      2.03 s ± 26.8 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
    • old
      62.7 ms ± 178 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
  • jitted creation
    f = jax.jit(lambda: snp.ones(10000*((2, 2),)))
    f() # trigger jit
    %timeit -n 3 -r 3 z = f()
    • new
      36.6 ms ± 345 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
    • old
      33.3 µs ± 30.2 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
  • smaller creation
    %timeit -n 3 -r 3 x = snp.ones(5*((512, 512),))
    • new
      2.1 ms ± 176 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
    • old
      1.38 ms ± 125 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
  • multiplying blocks
    x = snp.ones(512*((512, 512),))
    %timeit -n 3 -r 3 y = x @ x
    • new
      406 ms ± 60.5 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
    • old
      2.21 s ± 55.7 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
  • reduction
    %timeit -n 3 -r 3 z = snp.linalg.norm(x)
    • new
      567 ms ± 50.5 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
    • old
      152 ms ± 9.02 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
  • jited reduction
    f = jax.jit(lambda x: snp.linalg.norm(x))
    f(x) # trigger jit
    %timeit -n 3 -r 3 z = f(x)
    • new
      269 ms ± 1.6 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
    • old
      147 ms ± 1.35 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
  • smaller reduction
    x = snp.ones(5*((512, 512),))
    %timeit -n 3 -r 3 z = snp.linalg.norm(x)
    • new
      3.86 ms ± 266 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
    • new (jit as above)
      1.63 ms ± 135 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
    • old
      1.49 ms ± 203 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)
    • old (jit as above)
      1.65 ms ± 228 µs per loop (mean ± std. dev. of 3 runs, 3 loops each)

@bwohlberg bwohlberg added the improvement Improvement of existing code, including addressing of omissions or inconsistencies label Mar 23, 2022
@Michael-T-McCann Michael-T-McCann force-pushed the mike/BlockArray_tuple branch 4 times, most recently from ef3988b to 2efb2ae Compare April 14, 2022 18:16
@codecov
Copy link

codecov bot commented Apr 14, 2022

Codecov Report

Merging #259 (715711c) into main (a1838ee) will increase coverage by 0.20%.
The diff coverage is 98.25%.

@@            Coverage Diff             @@
##             main     #259      +/-   ##
==========================================
+ Coverage   93.86%   94.07%   +0.20%     
==========================================
  Files          51       49       -2     
  Lines        3701     3241     -460     
==========================================
- Hits         3474     3049     -425     
+ Misses        227      192      -35     
Flag Coverage Δ
unittests 94.07% <98.25%> (+0.20%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
scico/numpy/_wrappers.py 97.10% <97.10%> (ø)
scico/numpy/blockarray.py 97.33% <97.33%> (ø)
scico/_flax.py 96.92% <100.00%> (ø)
scico/_generic_operators.py 91.81% <100.00%> (ø)
scico/denoiser.py 89.02% <100.00%> (ø)
scico/functional/_dist.py 100.00% <100.00%> (ø)
scico/functional/_functional.py 90.41% <100.00%> (ø)
scico/functional/_indicator.py 100.00% <100.00%> (ø)
scico/functional/_norm.py 100.00% <100.00%> (ø)
scico/linop/_circconv.py 83.83% <100.00%> (ø)
... and 19 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update a1838ee...715711c. Read the comment docs.

@@ -85,7 +85,7 @@ def __init__(
def _eval(self, x: JaxArray) -> Union[JaxArray, BlockArray]:
if self.collapsable and self.collapse:
return snp.stack([op @ x for op in self.ops])
return BlockArray.array([op @ x for op in self.ops])
return BlockArray([op @ x for op in self.ops])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any opinions on the syntax snp.BlockArray(stuff) vs snp.blockarray(stuff)? The second one is little more like NumPy.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd lean towards copying NumPy style, but this is worth a discussion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made both do the same thing, but tended to write snp.blockarray(...) wherever I could. By comparison, np.ndarray and np.array do different things, with np.ndarray documented as a low-level method.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's OK, but perhaps add a note in the docs indicating that the snp.blockarray form is preferred.

scico/__init__.py Outdated Show resolved Hide resolved
scico/test/test_numpy.py Outdated Show resolved Hide resolved
@Michael-T-McCann Michael-T-McCann marked this pull request as ready for review April 20, 2022 18:46
@Michael-T-McCann Michael-T-McCann requested a review from tbalke April 21, 2022 18:06
def is_nested(x: Any) -> bool:
"""Check if input is a list/tuple containing at least one list/tuple.

Args:
x: Object to be tested.

Returns:
``True`` if `x` is a list/tuple of list/tuples, otherwise
``False``.
``True`` if `x` is a list/tuple of list/tuples, ``False`` otherwise.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean list/tuple of lists/tuples?

A = FiniteDifference(
input_shape=input_shape, input_dtype=input_dtype, axes=axes, append=append
)
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checkin, was this pass-cause left here with intention? And if so, why not use if not?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch. No intention, probably just quickly hacking the old test. Should be more readable now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
improvement Improvement of existing code, including addressing of omissions or inconsistencies
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Consider simplification or replacement of BlockArray
3 participants