Update jax array api #7

Merged
EiffL merged 29 commits into main from update-jax-array-api
Apr 28, 2024

EiffL
Member

@EiffL EiffL commented Mar 1, 2024

This is a PR led by @ASKabalan to update the project to the JAX v0.4+ API

@EiffL
Member Author

EiffL commented Mar 2, 2024

I reformatted the code on main and in this branch using yapf to make it easier to review the bits that have changed

@EiffL
Member Author

EiffL commented Mar 2, 2024

I'm still working through it, but some top level comments:

  • We don't want to use a fork of cuDecomp, we should use the upstream version. If we need to add something to it, we can always submit a PR to the upstream project.

  • The repo for the library itself should be kept clean and tidy, so we don't want to commit here the results of experimentation runs. We can have a separate repo where we save the results of experiments. It is fine however to have the instrumentation scripts here.

  • As a library, jaxDecomp should not infringe on API boundaries. By that I mean the scope of jaxDecomp should be strictly limited to providing FFT ops that are as close as possible to drop-in replacements for the official ones. It is out of scope to define utility functions like map_global_array_from_slices. Of course we can define them and use them internally if we need them, but end users should not have to use them, or even be told about them, because that would affect how they write their scripts beyond the strict minimum of having access to distributed ops.
    The point here is to maintain clean interfaces between what different libraries provide, and to have as small a footprint as possible. The bigger the API we provide, the more ways it can break when JAX changes and the costlier it is to maintain.
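
A minimal sketch of what the drop-in goal in the last bullet could look like from the user's side, assuming only the public jnp.fft.fftn signature. The names power_spectrum and its fftn argument are hypothetical, chosen for illustration, and are not part of jaxDecomp:

```python
import jax.numpy as jnp


def power_spectrum(field, fftn=jnp.fft.fftn):
    """User code written only against the standard fftn signature.

    `fftn` is a hypothetical hook: any callable with the same signature as
    jnp.fft.fftn could be passed in without touching the rest of the script.
    """
    return jnp.abs(fftn(field)) ** 2


field = jnp.ones((4, 4, 4))
pk = power_spectrum(field)  # uses the official JAX FFT by default
print(pk.shape)
```

If a distributed FFT keeps the same call signature, switching to it would only change the callable passed in; no extra helper functions would leak into the user's script.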

@ASKabalan ASKabalan force-pushed the update-jax-array-api branch from 548e910 to 2333fd7 Compare March 4, 2024 11:11
@ASKabalan ASKabalan force-pushed the update-jax-array-api branch 3 times, most recently from d1b9683 to 0920183 Compare March 12, 2024 19:01
@ASKabalan ASKabalan force-pushed the update-jax-array-api branch 5 times, most recently from e33d034 to 8989a8a Compare March 17, 2024 00:50
Comment on lines 1 to 2
# To use on Jean Zay with the TKC project

README file should be moved to the parent dir

Collaborator

I moved the benchmarks to another repo.
Tell me if you still want the benchmark code here;
I can add an 'example' folder instead.

Comment on lines 179 to 195
sorted_df = pd.DataFrame(sorted_dfs)
label = f"{method}-{backend}-{group['nodes'].values[0]}nodes" if nodes_in_label else f"{method}-{backend}"

ax.plot(sorted_df['x'].values, sorted_df['time'], marker='o', linestyle='-', label=label)

# add title nb of gpus
ax.set_title(f"Number of GPUs: {gpu_size}")
# Set x-axis ticks exactly at the provided data size values
ax.set_yscale('log')
data_sizes = list(dict.fromkeys(data_sizes))
data_sizes.sort()
ax.set_xticks(data_sizes)
ax.set_yticks([1e-3, 1e-2, 1e-1, 1e0])

# Set labels and title
ax.set_xlabel('Data Size')
ax.set_ylabel('Time')

FYI, in case you are not aware of it yet: pandas has its own plotting API on top of matplotlib, so you can get the same result with much less boilerplate once your dataframe is ready.
https://pandas.pydata.org/pandas-docs/stable/user_guide/visualization.html
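
A minimal sketch of that suggestion, assuming a dataframe with the x, time and method columns used in the snippet above. The numbers are placeholders for illustration only, not measured timings, and the method labels "A" and "B" are hypothetical:

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import pandas as pd

# Placeholder values for illustration only; these are not measured timings.
df = pd.DataFrame({
    "x": [128, 256, 512, 128, 256, 512],
    "time": [0.002, 0.010, 0.080, 0.003, 0.015, 0.120],
    "method": ["A", "A", "A", "B", "B", "B"],
})

# One column per method, indexed by data size, then a single plot call.
wide = df.pivot(index="x", columns="method", values="time")
ax = wide.plot(marker="o", linestyle="-", logy=True, title="Number of GPUs: 4")

# Ticks exactly at the data sizes present in the dataframe.
ax.set_xticks(list(wide.index))
ax.set_xlabel("Data Size")
ax.set_ylabel("Time")
```

DataFrame.plot returns a regular matplotlib Axes, so any remaining tweaks can still go through the usual matplotlib calls.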

Comment on lines +223 to +189
# find a way to finalize pytest
def test_end():
# Make sure that it is cleaned up
# This has to be this way because pytest runs the "global code" before running the tests
# There are other solutions https://stackoverflow.com/questions/41871278/pytest-run-a-function-at-the-end-of-the-tests

What about defining some fixtures with scope=session in a conftest.py ?

This could also be useful for the initialize at the beginning..

Collaborator

Doesn't session scope span multiple pytest files?
I am not sure how pytest handles this,
but the cleanup needs to be done after every test.

jaxdecomp/fft.py Outdated
@@ -1,9 +1,11 @@
import jax.numpy as jnp
from jax._src.numpy.fft import _fft_norm

_fft_norm is redefined below (code copied from jax source)

Member Author

right, can we remove the code below and just use this function?

Collaborator

done in 3786c77

README.md Outdated
### Step II: Building jaxDecomp

This step is easier :-) From this directory, just run the following
From this directory, just run the followg

Member Author

@EiffL EiffL left a comment

I added a few questions, but it looks pretty good!

If I understand correctly, there is currently a limitation that things only work for a 3D cube. Is this the same issue as the one I was too lazy to fix here:

Only transpose operations that preserve the size of local slices are supported
#1


# Check the forward FFT
assert_allclose(global_karray_slice.real, karray.real, atol=1e-10)
assert_allclose(global_karray_slice.imag, karray.imag, atol=1e-10)
# Non cube shape .. need to think about the slice coming from JAX
Member Author

hummmm what does this mean? it doesn't work if the array has an arbitrary shape?

Collaborator

It does. (This is just a testing issue, not a production code issue.)
The issue is that pdims is [Z][X].
The code before used pdims as pdims[1] pdims[0], which obviously works because 2x4 and 4x2 will give the right FFT.
However, when we construct the global array dims I use pdims for the partition specs, and in this case they have to be accurate.

I.e.: (global is 512 512 512) a 2x4 pdims will give in the old code (256 128 512).
For cuDecomp, pdims[0] is the Z axis and pdims[1] is the Y axis, so if I do this:

Collaborator

The comment is not about pfft; it is about the JAX fft and the transpose needed.
pfft works on arbitrary shapes.
Only the test comparison against the array is not working: the shapes of global_karray_slice are weird.
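
The ordering point made in the replies above can be sketched with plain shape arithmetic. This assumes, to match the (256 128 512) example, that the first array axis is split pdims[0] ways and the second pdims[1] ways; local_slice_shape is a hypothetical helper, not a jaxDecomp function:

```python
def local_slice_shape(global_shape, pdims):
    """Per-device shape when axis 0 is split pdims[0] ways and axis 1 pdims[1] ways.

    Which array axis each pdims entry maps to is an assumption of this sketch.
    """
    n0, n1, n2 = global_shape
    return (n0 // pdims[0], n1 // pdims[1], n2)


global_shape = (512, 512, 512)
print(local_slice_shape(global_shape, (2, 4)))  # (256, 128, 512)
print(local_slice_shape(global_shape, (4, 2)))  # (128, 256, 512)
```

Both orderings use the same number of devices, but the local slice shapes differ, so a test that compares slices against a reference array has to use the same ordering as the partition specs.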

global_shape[1] // pdims[1]].imag,
karray.imag,
atol=1e-10)
#assert_allclose(
Member Author

why are we removing those?

Collaborator

The shapes of the JAX arrays are wrong in the test.
This is not a pfft error, just a test error.
The FFT gives the right results.

@EiffL EiffL marked this pull request as ready for review March 21, 2024 04:40
@ASKabalan ASKabalan force-pushed the update-jax-array-api branch from 667022d to 2d6552e Compare March 21, 2024 08:25
@ASKabalan ASKabalan force-pushed the update-jax-array-api branch 2 times, most recently from 6dde8e4 to 5005851 Compare March 21, 2024 16:55
@ASKabalan
Collaborator

ASKabalan commented Apr 27, 2024

File "/mnt/home/flanusse/repo/jaxDecomp/scripts/demo.py", line 58, in
recarray = slice_unpad(exchanged_reduced, padding_width, pdims)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: custom_partitioner: TypeError: pad operand and padding_value must be same dtype: got complex64 and float32.

The issue is the fake padding_value that I give: it should have the same dtype as the array.
The issue is here

Maybe doing arr[first:-last] is just better.

I will fix it once you finish reviewing.

Member Author

@EiffL EiffL left a comment

Seems to be working :-)

jaxdecomp/fft.py Outdated

# Has to be jitted here because _fft_norm will act on non fully addressable global array
# Which means this should be jit wrapped
@partial(jit, static_argnums=(0,1,3))
Member Author

@ASKabalan , same question here

@ASKabalan
Collaborator

> Seems to be working :-)

If you are talking about from mpi4py import MPI before jax.distributed(), for me it doesn't work.
I use srun, not mpirun, to run the code. What do you use?

@EiffL
Member Author

EiffL commented Apr 27, 2024

I meant everything seems to run (minus details noted above).

I've removed all references to mpi4py because it's a pain to install and is not necessary now that JAX understands distribution itself.

I'm going to go ahead and merge, we can treat the rest of the comments as small pull requests. This branch is way better than the main branch already.

@ASKabalan
Collaborator

Before merging, please see my comment about the distributed global_state
here

@EiffL
Member Author

EiffL commented Apr 27, 2024

hummm what's the comment?

@ASKabalan
Collaborator

You must use jax.process_index() instead of global_state.rank
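
A minimal sketch of that replacement, assuming the rank was only needed to identify the current process. It uses public JAX calls only and also runs in a single process, where the index is 0 and the count is 1:

```python
import jax

rank = jax.process_index()  # index of this process among all JAX processes
size = jax.process_count()  # total number of JAX processes

# Typical use: restrict logging or file output to a single process.
if rank == 0:
    print(f"{size} process(es), {jax.local_device_count()} local device(s)")
```

In a multi-process run these values reflect the setup made by jax.distributed.initialize(), so no separate rank bookkeeping is needed.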

@EiffL
Member Author

EiffL commented Apr 27, 2024

Maybe you haven't pushed your review? I can't see a comment at that line

## Usage

The API is still under development, so it doesn't look very streamlined, but you
can already do the following:
```python
from mpi4py import MPI
```
Collaborator

The WORLD communicator must be set for jax.distributed to work; without it, it won't work.

Example:

import jax
#from mpi4py import MPI <-- without this, it doesn't work

jax.distributed.initialize()
print(jax.devices())

@EiffL if you want to remove mpi4py as a requirement, we will have to init the WORLD communicator ourselves. I would have to recreate the init function, which just instantiates the singleton, like this:

void init(){jd::GridDescriptorManager::getInstance();};

Then we do this

import jax
import jaxdecomp

jaxdecomp.init()
jax.distributed.initialize()

print(jax.devices())

Ok for you?

Collaborator

This may work for you, but not on JeanZay where we have to use srun

Member Author

no, we shouldn't need to init things ourselves for jax.distributed to work.

interiour_padding = 0

# unlike jnp.pad lax.pad can unpad if given negative values
return lax.pad(arr, padding_value=0.0, padding_config=((-first_x, -last_x , interiour_padding), (-first_y, -last_y , interiour_padding), (-first_z, -last_z , interiour_padding)))
Collaborator

Here, the padding_value has to be the same dtype as the array.

Or maybe just use [first:-last] slicing.
What do you think @EiffL?
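
Both options raised in this comment can be sketched as follows. This is an illustration under stated assumptions, not the code that was merged: unpad_with_lax, unpad_with_slices and the widths layout (one (first, last) pair per axis) are hypothetical names for this example:

```python
import jax.numpy as jnp
from jax import lax


def unpad_with_lax(arr, widths):
    """Remove (first, last) elements per axis using negative lax.pad widths."""
    # The padding value must share the operand dtype, e.g. complex64.
    zero = jnp.zeros((), dtype=arr.dtype)
    config = [(-first, -last, 0) for first, last in widths]  # 0 = no interior padding
    return lax.pad(arr, zero, config)


def unpad_with_slices(arr, widths):
    """Same result with plain slicing; explicit stops avoid the arr[first:-0] pitfall."""
    index = tuple(
        slice(first, dim - last) for (first, last), dim in zip(widths, arr.shape)
    )
    return arr[index]


arr = jnp.arange(4 * 5 * 6, dtype=jnp.float32).reshape(4, 5, 6).astype(jnp.complex64)
widths = ((1, 1), (2, 0), (0, 3))

a = unpad_with_lax(arr, widths)
b = unpad_with_slices(arr, widths)
print(a.shape, a.dtype)
```

Note that a literal arr[first:-last] returns an empty array when last is 0, which is why the slicing variant computes the stop index from the axis length.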

@EiffL
Member Author

EiffL commented Apr 27, 2024

Ah ok, I see your comments now.

> WORLD communicator must be set in order for jax.distributed to work, without it it won't work

I have no problem running this with mpirun, jax.distributed works as expected, with no mpi4py. Where do you see an issue?

@ASKabalan
Collaborator

ASKabalan commented Apr 27, 2024

When I use srun, I get this:

CUDA backend failed to initialize: jaxlib/cuda/versions_helpers.cc:90: operation gpuDeviceGetAttribute( &major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device) failed: CUDA_ERROR_NOT_INITIALIZED (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Unless a communicator is set.
I can quickly recreate the init function and test

@EiffL
Member Author

EiffL commented Apr 27, 2024

Note that I'm instantiating the MPI context here:

CHECK_MPI_EXIT(MPI_Init(nullptr, nullptr));

@ASKabalan
Collaborator

This is lazily executed at the first PFFT execution.
Too late for that.

I will try to recreate the init and tell you

My comment explains it

@EiffL
Member Author

EiffL commented Apr 27, 2024

Yup, can you show me exactly the slurm config you use when you have this issue?

@EiffL
Member Author

EiffL commented Apr 27, 2024

Sure, but jax.distributed does not depend on mpi4py, we should not need it.

(everything (all tests and all scripts) work fine on my machine without mpi4py)

@ASKabalan
Collaborator

ASKabalan commented Apr 27, 2024

I think I force pushed on your commit ..

@ASKabalan ASKabalan force-pushed the update-jax-array-api branch from e870541 to d4a2b1c Compare April 27, 2024 23:26
@EiffL
Member Author

EiffL commented Apr 27, 2024

no no, we don't need to init MPI comm, send me your slurm script, I'll modify it, it's just a matter of gpu affinity with tasks.

@ASKabalan
Collaborator

> no no, we don't need to init MPI comm, send me your slurm script, I'll modify it, it's just a matter of gpu affinity with tasks.

Slack?

@EiffL EiffL force-pushed the update-jax-array-api branch from d4a2b1c to e870541 Compare April 28, 2024 00:27
@EiffL EiffL merged commit 68d209d into main Apr 28, 2024
@ASKabalan ASKabalan deleted the update-jax-array-api branch July 4, 2024 17:53
@ASKabalan ASKabalan restored the update-jax-array-api branch July 4, 2024 17:53
@ASKabalan ASKabalan deleted the update-jax-array-api branch December 21, 2024 18:36