Update jax array api #7
Conversation
I reformatted the code on main and in this branch using yapf to make it easier to review the bits that have changed.
I'm still working through it, but some top-level comments:
Force-pushed from 548e910 to 2333fd7.
Force-pushed from d1b9683 to 0920183.
Force-pushed from e33d034 to 8989a8a.
benchmarks/scripts/README.md (outdated):
```markdown
# To use on Jean Zay with the TKC project
```
README file should be moved to the parent dir
I moved the benchmarks to another repo.
Tell me if you still want the benchmark code to be here; I can add an 'example' folder instead.
benchmarks/scripts/plotter.py (outdated):
```python
sorted_df = pd.DataFrame(sorted_dfs)
label = f"{method}-{backend}-{group['nodes'].values[0]}nodes" if nodes_in_label else f"{method}-{backend}"

ax.plot(sorted_df['x'].values, sorted_df['time'], marker='o', linestyle='-', label=label)

# add title nb of gpus
ax.set_title(f"Number of GPUs: {gpu_size}")
# Set x-axis ticks exactly at the provided data size values
ax.set_yscale('log')
data_sizes = list(dict.fromkeys(data_sizes))
data_sizes.sort()
ax.set_xticks(data_sizes)
ax.set_yticks([1e-3, 1e-2, 1e-1, 1e0])

# Set labels and title
ax.set_xlabel('Data Size')
ax.set_ylabel('Time')
```
FYI, in case you are not aware of it yet: pandas has its own plotting API on top of matplotlib, so you can do the same without all the boilerplate once your dataframe is ready:
https://pandas.pydata.org/pandas-docs/stable/user_guide/visualization.html
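For illustration, a minimal sketch of what that could look like with the pandas plotting API (the 'x' and 'time' column names come from the snippet above; the log scale, title, and axis labels are assumptions about the intended plot):

```python
import matplotlib.pyplot as plt

# sorted_df, label and gpu_size are the objects built in the snippet above.
ax = sorted_df.plot(x='x', y='time', marker='o', linestyle='-',
                    label=label, logy=True)
ax.set_title(f"Number of GPUs: {gpu_size}")
ax.set_xlabel('Data Size')
ax.set_ylabel('Time')
plt.show()
```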
```python
# find a way to finalize pytest
def test_end():
  # Make sure that it is cleaned up
  # This has to be this way because pytest runs the "global code" before running the tests
  # There are other solutions https://stackoverflow.com/questions/41871278/pytest-run-a-function-at-the-end-of-the-tests
```
What about defining some fixtures with scope="session" in a conftest.py?
This could also be useful for the initialization at the beginning.
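For instance, a minimal sketch of what that could look like (the init/cleanup bodies are placeholders, not the actual jaxdecomp API):

```python
# conftest.py
import pytest

@pytest.fixture(scope="session", autouse=True)
def decomp_session():
    # Runs once before the first test of the session:
    # the jaxdecomp / jax.distributed initialization could go here.
    yield
    # Runs once after the last test of the session:
    # the cleanup currently done in test_end() could live here instead.
```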
Isn't session scope shared across multiple pytest files?
I am not sure how pytest works, but the cleanup needs to be done after every test.
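If the cleanup really has to run after every single test (rather than once per session), a function-scoped autouse fixture in a conftest.py would do that; a sketch, with the actual cleanup call left as a placeholder:

```python
# conftest.py
import pytest

@pytest.fixture(autouse=True)
def cleanup_after_each_test():
    yield  # the test body runs here
    # per-test cleanup goes here; pytest executes it after every test function
```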
jaxdecomp/fft.py (outdated):
```python
import jax.numpy as jnp
from jax._src.numpy.fft import _fft_norm
```
`_fft_norm` is redefined below (code copied from the JAX source).
right, can we remove the code below and just use this function?
done in 3786c77
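For reference, reusing the helper instead of keeping a local copy would look roughly like the sketch below. Note that jax._src.numpy.fft._fft_norm is a private JAX API whose exact signature may change between versions; here it is assumed to take the transform sizes, the function name, and the norm string, and to return the scale factor JAX itself applies:

```python
import jax.numpy as jnp
from jax._src.numpy.fft import _fft_norm  # private JAX helper, may change between versions

def apply_fft_norm(karray, fft_sizes, norm):
    # Scale the transformed array with the same normalization factor
    # jnp.fft uses, instead of re-implementing the norm logic locally.
    return karray * _fft_norm(jnp.asarray(fft_sizes), "fft", norm)
```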
README.md (outdated):
```markdown
### Step II: Building jaxDecomp

This step is easier :-) From this directory, just run the following
From this directory, just run the following
```
I added a few questions, but it looks pretty good!
If I understand correctly, right now there is a limitation that things will only work for a 3D cube. Is this the same issue as the one I was too lazy to fix here:
"Only transpose operations that preserve the size of local slices are supported" (#1)?
tests/test_fft.py (outdated):
```python
# Check the forward FFT
assert_allclose(global_karray_slice.real, karray.real, atol=1e-10)
assert_allclose(global_karray_slice.imag, karray.imag, atol=1e-10)
# Non cube shape .. need to think about the slice coming from JAX
```
Hummmm, what does this mean? It doesn't work if the array has an arbitrary shape?
It does (this is just a testing issue, not a production code issue).
The issue is that pdims is [Z][X].
The code before used pdims as pdims[1], pdims[0], which happens to work because 2x4 and 4x2 both give the right FFT.
However, when we construct the global array dims I use pdims for the partition specs, and in this case they have to be accurate.
I.e., if the global shape is (512, 512, 512), a 2x4 pdims gives (256, 128, 512) in the old code.
For cuDecomp, pdims[0] is the Z axis and pdims[1] is the Y axis, so if I do this:
The comment is not about pfft, it is about the JAX FFT and the transpose needed.
pfft works on arbitrary shapes.
It's just that testing against the JAX array isn't working: the shapes of global_karray_slice are weird.
tests/test_fft.py (outdated):
```python
                    global_shape[1] // pdims[1]].imag,
                karray.imag,
                atol=1e-10)
#assert_allclose(
```
why are we removing those?
The shapes of the JAX arrays are wrong in the test.
This is not a pfft error, just a test error; the FFT gives the right results.
Force-pushed from 667022d to 2d6552e.
Force-pushed from 6dde8e4 to 5005851.
The issue is the fake … Maybe doing … I will fix once you finish reviewing.
Seems to be working :-)
jaxdecomp/fft.py (outdated):
```python
# Has to be jitted here because _fft_norm will act on non fully addressable global array
# Which means this should be jit wrapped
@partial(jit, static_argnums=(0, 1, 3))
```
@ASKabalan, same question here.
If you are talking about …
I meant everything seems to run (minus details noted above). I've removed all references to mpi4py because it's a pain to install, and it is not necessary now that jax understands distribution itself. I'm going to go ahead and merge; we can treat the rest of the comments as small pull requests. This branch is way better than the main branch already.
Before merging, please see my comment about the …
Hummm, what's the comment?
You must use …
Maybe you haven't pushed your review? I can't see a comment at that line.
## Usage

The API is still under development, so it doesn't look very streamlined, but you
can already do the following:

```python
from mpi4py import MPI
```
The WORLD communicator must be set in order for jax.distributed to work; without it, it won't work. For example:

```python
import jax
# from mpi4py import MPI  <-- without this, it doesn't work
jax.distributed.initialize()
print(jax.devices())
```

@EiffL if you want to remove mpi4py as a requirement, we will have to init the WORLD communicator ourselves. I would have to recreate the init function that just instantiates the singleton, like this:

```cpp
void init() { jd::GridDescriptorManager::getInstance(); };
```

Then we do this:

```python
import jax
import jaxdecomp

jaxdecomp.init()
jax.distributed.initialize()
print(jax.devices())
```

Ok for you?
This may work for you, but not on Jean Zay, where we have to use srun.
no, we shouldn't need to init things ourselves for jax.distributed to work.
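For context, a minimal sketch of running jax.distributed without mpi4py under srun. The COORDINATOR_ADDR variable and the port are assumptions (not part of this PR), while SLURM_NTASKS and SLURM_PROCID are standard Slurm variables; on recent JAX versions a bare jax.distributed.initialize() can also auto-detect the Slurm allocation:

```python
import os
import jax

# Either let JAX auto-detect the cluster environment:
# jax.distributed.initialize()

# Or pass the process layout explicitly from srun-provided variables:
jax.distributed.initialize(
    coordinator_address=os.environ["COORDINATOR_ADDR"],  # e.g. "node001:12345" (hypothetical env var)
    num_processes=int(os.environ["SLURM_NTASKS"]),
    process_id=int(os.environ["SLURM_PROCID"]),
)
print(jax.devices())
```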
jaxdecomp/_src/padding.py (outdated):
```python
interiour_padding = 0

# unlike jnp.pad, lax.pad can unpad if given negative values
return lax.pad(arr,
               padding_value=0.0,
               padding_config=((-first_x, -last_x, interiour_padding),
                               (-first_y, -last_y, interiour_padding),
                               (-first_z, -last_z, interiour_padding)))
```
Here, the padding_value has to be the same type as the array.
Or maybe just use [first:-last] slicing instead.
What do you think @EiffL?
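To illustrate the two options, here is a rough sketch (the unpad helper and the per-axis first/last tuples are hypothetical, not the actual padding.py code):

```python
import jax.numpy as jnp
from jax import lax

def unpad(arr, first, last):
    # Option 1: lax.pad with negative widths, using a padding_value that
    # matches the array dtype (lax.pad requires matching dtypes).
    zero = jnp.zeros((), dtype=arr.dtype)
    via_pad = lax.pad(arr, zero,
                      padding_config=[(-f, -l, 0) for f, l in zip(first, last)])

    # Option 2: plain slicing (careful: a trailing width of 0 needs None, not -0).
    via_slice = arr[first[0]:arr.shape[0] - last[0],
                    first[1]:arr.shape[1] - last[1],
                    first[2]:arr.shape[2] - last[2]]
    return via_pad, via_slice
```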
Ah ok, I see your comments now.
I have no problem running this with mpirun; jax.distributed works as expected, with no mpi4py. Where do you see an issue?
When I use … I get this: …
Unless a communicator is set.
Note that I'm instantiating the MPI context here: jaxDecomp/src/grid_descriptor_mgr.cc, line 26 (commit aededef).
This is lazily executed at the first pfft execution. I will try to recreate the init and tell you. My comment explains it.
Yup, can you show me exactly the slurm config you use when you have this issue?
Sure, but jax.distributed does not depend on mpi4py, so we should not need it. (Everything, all tests and all scripts, works fine on my machine without mpi4py.)
I think I force-pushed on your commit...
Force-pushed from e870541 to d4a2b1c.
No no, we don't need to init the MPI comm. Send me your slurm script and I'll modify it; it's just a matter of GPU affinity with tasks.
Slack?
Force-pushed from d4a2b1c to e870541.
This is a PR led by @ASKabalan to update the project to the JAX v0.4+ API