diff --git a/docs/modules/interoperability.rst b/docs/modules/interoperability.rst index eb2217c1..800d4e79 100644 --- a/docs/modules/interoperability.rst +++ b/docs/modules/interoperability.rst @@ -418,7 +418,6 @@ Since this is an experimental feature, there are some limitations: - Kernel launch dimensions are inferred from the shape of the first argument. - Input arguments are followed by output arguments in the Warp kernel definition. - There must be at least one input argument and at least one output argument. - - Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument). - All arrays must be contiguous. - Only the CUDA backend is supported. @@ -462,6 +461,233 @@ Here is an example of an operation with three inputs and two outputs:: print(x) print(y) +Using shardmap for distributed computation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Warp can be used in conjunction with JAX's `shard_map `_ to perform distributed multi-GPU computations. + +To achieve this, the JAX distributed environment must be initialized (see `Distributed Arrays and Automatic Parallelization `_ for more details): + +.. code-block:: python + + import jax + jax.distributed.initialize() + +This initialization must be called at the beginning of your program, before any other JAX operations. + +Here's an example of how to use `shard_map` with a Warp kernel: + +.. code-block:: python + + import warp as wp + import jax + import jax.numpy as jnp + from jax.sharding import PartitionSpec as P + from jax.experimental.multihost_utils import process_allgather as allgather + from jax.experimental.shard_map import shard_map + from warp.jax_experimental import jax_kernel + import numpy as np + + # Initialize JAX distributed environment + jax.distributed.initialize() + num_gpus = jax.device_count() + + def print_on_process_0(*args, **kwargs): + if jax.process_index() == 0: + print(*args, **kwargs) + + print_on_process_0(f"Running on {num_gpus} GPU(s)") + + @wp.kernel + def multiply_by_two_kernel( + a_in: wp.array(dtype=wp.float32), + a_out: wp.array(dtype=wp.float32), + ): + index = wp.tid() + a_out[index] = a_in[index] * 2.0 + + jax_warp_multiply = jax_kernel(multiply_by_two_kernel) + + def warp_multiply(x): + result = jax_warp_multiply(x) + return result + + # a_in here is the full sharded array with shape (M,) + # The output will also be a sharded array with shape (M,) + def warp_distributed_operator(a_in): + def _sharded_operator(a_in): + # Inside the sharded operator, a_in is a local shard on each device + # If we have N devices and input size M, each shard has shape (M/N,) + + # warp_multiply applies the Warp kernel to the local shard + result = warp_multiply(a_in)[0] + + # result has the same shape as the input shard (M/N,) + return result + + # shard_map distributes the computation across devices + return shard_map( + _sharded_operator, + mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"), + in_specs=(P("x"),), # Input is sharded along the 'x' axis + out_specs=P("x"), # Output is also sharded along the 'x' axis + check_rep=False, + )(a_in) + + print_on_process_0("Test distributed multiplication using JAX + Warp") + + devices = jax.devices() + mesh = jax.sharding.Mesh(np.array(devices), "x") + sharding_spec = jax.sharding.NamedSharding(mesh, P("x")) + + input_size = num_gpus * 5 # 5 elements per device + single_device_arrays = jnp.arange(input_size, dtype=jnp.float32) + + # Define the shape of the input array based on the total input size + shape = (input_size,) + + # Create a list of arrays by distributing the single_device_arrays across the available devices + # Each device will receive a portion of the input data + arrays = [ + jax.device_put(single_device_arrays[index], d) # Place each element on the corresponding device + for d, index in sharding_spec.addressable_devices_indices_map(shape).items() + ] + + # Combine the individual device arrays into a single sharded array + sharded_array = jax.make_array_from_single_device_arrays(shape, sharding_spec, arrays) + + # sharded_array has shape (input_size,) but is distributed across devices + print_on_process_0(f"Input array: {allgather(sharded_array)}") + + # warp_result has the same shape and sharding as sharded_array + warp_result = warp_distributed_operator(sharded_array) + + # allgather collects results from all devices, resulting in a full array of shape (input_size,) + print_on_process_0("Warp Output:", allgather(warp_result)) + +In this example, `shard_map` is used to distribute the computation across available devices. The input array `a_in` is sharded along the 'x' axis, and each device processes its local shard. The Warp kernel `multiply_by_two_kernel` is applied to each shard, and the results are combined to form the final output. + +This approach allows for efficient parallel processing of large arrays, as each device works on a portion of the data simultaneously. + +To run this program on multiple GPUs, you must have OpenMPI installed. You can consult the `OpenMPI installation guide `_ for instructions on how to install it. Once OpenMPI is installed, you can use `mpirun` with the following command: + +.. code-block:: bash + + mpirun -np python .py + + +Specifying launch dimensions for matrix operations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In some cases, particularly for matrix operations, it's necessary to specify the launch dimensions for Warp kernels. This is because the default behavior of inferring dimensions from the first argument may not always be suitable for matrix operations. Here's an example of a distributed matrix multiplication using Warp and JAX: + +.. code-block:: python + + import warp as wp + import jax + import jax.numpy as jnp + from jax.sharding import PartitionSpec as P + from jax.experimental.multihost_utils import process_allgather as allgather + from jax.experimental.shard_map import shard_map + from warp.jax_experimental import jax_kernel + import numpy as np + + jax.distributed.initialize() + num_gpus = jax.device_count() + + def print_on_process_0(*args, **kwargs): + if jax.process_index() == 0: + print(*args, **kwargs) + + print_on_process_0(f"Running on {num_gpus} GPU(s)") + + @wp.kernel + def matmul_kernel( + a: wp.array2d(dtype=wp.float32), + b: wp.array2d(dtype=wp.float32), + c: wp.array2d(dtype=wp.float32), + ): + # a: (M/num_gpus, K), b: (K, N), c: (M/num_gpus, N) + i, j = wp.tid() + M = a.shape[0] # M/num_gpus + K = a.shape[1] # K + N = b.shape[1] # N + if i < M and j < N: + s = wp.float32(0.0) + for k in range(K): + s += a[i, k] * b[k, j] + c[i, j] = s + + # Specify launch dimensions based on the number of GPUs + def create_jax_warp_matmul(M, N): + # M: total rows, N: total columns + block_size_m = M // num_gpus # Rows per GPU + block_size_n = N # All columns + return jax_kernel(matmul_kernel, launch_dims=(block_size_m, block_size_n)) + + def warp_distributed_matmul(a, b): + # a: (M, K) sharded across GPUs, b: (K, N) replicated + M, K = a.shape + _, N = b.shape + jax_warp_matmul = create_jax_warp_matmul(M, N) + + def _sharded_operator(a_shard, b): + # a_shard: (M/num_gpus, K), b: (K, N) + return jax_warp_matmul(a_shard, b)[0] # Result: (M/num_gpus, N) + + return shard_map( + _sharded_operator, + mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"), + in_specs=(P("x", None), P(None, None)), # a sharded in first dim, b replicated + out_specs=P("x", None), # Output sharded in first dim + check_rep=False, + )(a, b) + + print_on_process_0("Test distributed matrix multiplication using JAX + Warp") + + # Define matrix dimensions + M = 8 * num_gpus # Scale M with the number of devices + K, N = 4, 6 + + # Create input matrices + a = jnp.arange(M * K, dtype=jnp.float32).reshape(M, K) # Shape: (M, K) + b = jnp.arange(K * N, dtype=jnp.float32).reshape(K, N) # Shape: (K, N) + + devices = jax.devices() + mesh = jax.sharding.Mesh(np.array(devices), "x") + sharding_spec_a = jax.sharding.NamedSharding(mesh, P("x", None)) + sharding_spec_b = jax.sharding.NamedSharding(mesh, P(None, None)) + + # Shard matrix A and replicate matrix B + sharded_a = jax.device_put(a, sharding_spec_a) # Sharded shape: (M/num_gpus, K) per device + replicated_b = jax.device_put(b, sharding_spec_b) # Replicated shape: (K, N) on all devices + + print_on_process_0(f"Input matrix A:\n{allgather(sharded_a)}") # Shape: (M, K) + print_on_process_0(f"Input matrix B:\n{allgather(replicated_b)}") # Shape: (K, N) + + warp_result = warp_distributed_matmul(sharded_a, replicated_b) # Sharded result: (M/num_gpus, N) per device + print_on_process_0("Warp Output:") + # Use allgather to collect results from all devices + print_on_process_0(allgather(warp_result)) # Shape: (M, N) + + jax_result = jnp.matmul(a, b) # Shape: (M, N) + print_on_process_0("JAX Output:") + print_on_process_0(jax_result) + + expected_shape = (M, N) + print_on_process_0(f"Expected shape: {expected_shape}") + print_on_process_0(f"Warp output shape: {warp_result.shape}") # Should be (M/num_gpus, N) on each device + print_on_process_0(f"JAX output shape: {jax_result.shape}") # Should be (M, N) + + allclose = jnp.allclose(allgather(warp_result), jax_result, atol=1e-5) + print_on_process_0(f"Allclose: {allclose}") + +In this example, we create a function `create_jax_warp_matmul` that calculates the launch dimensions based on the number of available GPUs. We use `jax.device_count()` to get the global number of GPUs and divide the `M` dimension (rows) of the matrix by this number. This ensures that each GPU processes an equal portion of the input matrix A. The `N` dimension (columns) remains unchanged as we're not sharding in that direction. + +Note that the launch dimensions are set to match the shape of the matrix portion on each GPU. The `block_size_m` is calculated by dividing the total number of rows by the number of GPUs, while `block_size_n` is set to the full width of the output matrix. + +Note that this is a naive implementation of matrix multiplication for the sake of this illustration, and there are many optimizations that can be made to improve performance. + .. _DLPack: DLPack diff --git a/warp/jax_experimental.py b/warp/jax_experimental.py index c3e3d072..8e78ab26 100644 --- a/warp/jax_experimental.py +++ b/warp/jax_experimental.py @@ -21,17 +21,22 @@ _registered_kernel_to_id = {} -def jax_kernel(wp_kernel): +def jax_kernel(wp_kernel, launch_dims=None): """Create a Jax primitive from a Warp kernel. NOTE: This is an experimental feature under development. + Args: + wp_kernel: The Warp kernel to be wrapped. + launch_dims: Optional. Specify the kernel launch dimensions. If None, + dimensions are inferred from the shape of the first argument. + This option when set will specify the output dimensions. + Current limitations: - All kernel arguments must be arrays. - - Kernel launch dimensions are inferred from the shape of the first argument. + - If launch_dims is not provided, kernel launch dimensions are inferred from the shape of the first argument. - Input arguments are followed by output arguments in the Warp kernel definition. - There must be at least one input argument and at least one output argument. - - Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument). - All arrays must be contiguous. - Only the CUDA backend is supported. """ @@ -47,7 +52,7 @@ def jax_kernel(wp_kernel): id = _registered_kernel_to_id[wp_kernel] def bind(*args): - return _jax_warp_p.bind(*args, kernel=id) + return _jax_warp_p.bind(*args, kernel=id, launch_dims=launch_dims) return bind @@ -106,7 +111,7 @@ def _get_jax_device(): device = jax.config.jax_default_device # if default device is not set, use first device if device is None: - device = jax.devices()[0] + device = jax.local_devices()[0] return device @@ -223,12 +228,17 @@ def base_type_is_compatible(warp_type, jax_ir_type): raise TypeError(f"Invalid or unsupported data type: {jax_ir_type}") # Abstract evaluation. - def jax_warp_abstract(*args, kernel=None): + def jax_warp_abstract(*args, kernel=None, launch_dims=None): wp_kernel = _registered_kernels[kernel] # All the extra arguments to the warp kernel are outputs. warp_outputs = [o.type for o in wp_kernel.adj.args[len(args) :]] - # TODO. Let's just use the first input dimension to infer the output's dimensions. - dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape)) + + if launch_dims is None: + # Use the first input dimension to infer the output's dimensions if launch_dims is not provided + dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape)) + else: + dims = launch_dims + jax_outputs = [] for o in warp_outputs: shape = list(dims) + list(get_vecmat_shape(o)) @@ -260,7 +270,7 @@ def jax_warp_abstract(*args, kernel=None): def default_layout(shape): return range(len(shape) - 1, -1, -1) - def warp_call_lowering(ctx, *args, kernel=None): + def warp_call_lowering(ctx, *args, kernel=None, launch_dims=None): if not kernel: raise Exception("Unknown kernel id " + str(kernel)) wp_kernel = _registered_kernels[kernel] @@ -272,12 +282,15 @@ def warp_call_lowering(ctx, *args, kernel=None): if not module.load(device): raise Exception("Could not load kernel on device") - # Infer dimensions from the first input. - warp_arg0 = wp_kernel.adj.args[0] - actual_shape0 = ir.RankedTensorType(args[0].type).shape - dims = strip_vecmat_dimensions(warp_arg0, actual_shape0) - warp_dims = collapse_into_leading_dimension(warp_arg0, dims) - + if launch_dims is None: + # Infer dimensions from the first input. + warp_arg0 = wp_kernel.adj.args[0] + actual_shape0 = ir.RankedTensorType(args[0].type).shape + dims = strip_vecmat_dimensions(warp_arg0, actual_shape0) + warp_dims = collapse_into_leading_dimension(warp_arg0, dims) + else: + dims = launch_dims + warp_dims = launch_dims # Figure out the types and shapes of the input arrays. arg_strings = [] operand_layouts = [] diff --git a/warp/tests/test_jax.py b/warp/tests/test_jax.py index 7cb24db3..d00d03ae 100644 --- a/warp/tests/test_jax.py +++ b/warp/tests/test_jax.py @@ -246,6 +246,60 @@ def f(): assert_np_equal(result_y, expected_y) +@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old") +def test_jax_kernel_launch_dims(test, device): + import jax.numpy as jp + + from warp.jax_experimental import jax_kernel + + n = 64 + m = 32 + + # Test with 1D launch dims + @wp.kernel + def add_one_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)): + tid = wp.tid() + y[tid] = x[tid] + 1.0 + + jax_add_one = jax_kernel( + add_one_kernel, launch_dims=(n - 2,) + ) # Intentionally not the same as the first dimension of the input + + @jax.jit + def f_1d(): + x = jp.arange(n, dtype=jp.float32) + return jax_add_one(x) + + # Test with 2D launch dims + @wp.kernel + def add_one_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)): + i, j = wp.tid() + y[i, j] = x[i, j] + 1.0 + + jax_add_one_2d = jax_kernel( + add_one_2d_kernel, launch_dims=(n - 2, m - 2) + ) # Intentionally not the same as the first dimension of the input + + @jax.jit + def f_2d(): + x = jp.zeros((n, m), dtype=jp.float32) + 3.0 + return jax_add_one_2d(x) + + # run on the given device + with jax.default_device(wp.device_to_jax(device)): + y_1d = f_1d() + y_2d = f_2d() + + result_1d = np.asarray(y_1d).reshape((n - 2,)) + expected_1d = np.arange(n - 2, dtype=np.float32) + 1.0 + + result_2d = np.asarray(y_2d).reshape((n - 2, m - 2)) + expected_2d = np.full((n - 2, m - 2), 4.0, dtype=np.float32) + + assert_np_equal(result_1d, expected_1d) + assert_np_equal(result_2d, expected_2d) + + class TestJax(unittest.TestCase): pass @@ -296,6 +350,10 @@ class TestJax(unittest.TestCase): TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices ) + add_function_test( + TestJax, "test_jax_kernel_launch_dims", test_jax_kernel_launch_dims, devices=jax_compatible_cuda_devices + ) + except Exception as e: print(f"Skipping Jax tests due to exception: {e}")