From 18cc2d25b9f33e3b0470dd4d928afa488ae227ef Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Mon, 2 Dec 2024 10:56:49 -0600 Subject: [PATCH] [ROCm] Implement RNN support --- jax/experimental/rnn.py | 32 ++- jaxlib/gpu/rnn_kernels.cc | 333 ++++++++++++++++-------- jaxlib/gpu/vendor.h | 94 +++++++ jaxlib/gpu_rnn.py | 40 ++- jaxlib/rocm/BUILD | 39 +++ jaxlib/tools/build_gpu_kernels_wheel.py | 3 +- tests/experimental_rnn_test.py | 11 +- 7 files changed, 424 insertions(+), 128 deletions(-) diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py index 4aa863708189..10904dc9c13b 100644 --- a/jax/experimental/rnn.py +++ b/jax/experimental/rnn.py @@ -175,6 +175,31 @@ def init_lstm_weight(rng: PRNGKeyArray, input_size: int, hidden_size: int, return jax.random.uniform( rng, shape=(param_count,), dtype=jnp.float32, minval=-k, maxval=k) +def swap_lstm_gates(weights, input_size, hidden_size, num_layers, bidirectional): + """Swaps the weights for the input and output gates for an LSTM model.""" + weights = jnp.asarray(weights) # Ensure weights are JAX arrays + flat_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers, bidirectional) + num_directions = 2 if bidirectional else 1 + + w_offsets = 0 + for l in range(num_layers): + for direction in range(num_directions): + # Iterate through all weight and bias gate names to swap gates in both weights and biases + for gate_name in ["W_ih", "W_hh", "b_ih", "b_hh"]: + shape = flat_shapes.pop(0) # Get the current shape and remove it from the list + num_elems = math.prod(shape) + matrix = weights[w_offsets:w_offsets + num_elems].reshape(shape) + + # Swap between the input and output gates (third and fourth gates) + gates = jnp.split(matrix, 4, axis=0) + swapped_matrix = jnp.concatenate([gates[0], gates[1], gates[3], gates[2]], axis=0) + + # Update the weights with swapped matrix + weights = weights.at[w_offsets:w_offsets + num_elems].set(swapped_matrix.flatten()) + w_offsets += num_elems + + return weights + def unpack_lstm_weights( weights: Array, input_size: int, hidden_size: int, num_layers: int, @@ -437,7 +462,9 @@ def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw): rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p)) rnn_fwd_p.def_abstract_eval(rnn_abstract_eval) if gpu_rnn: - mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda') + mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_fwd_lowering, platform='cuda') + if hasattr(gpu_rnn, "miopen_rnn_fwd_lowering"): + mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_fwd_lowering, platform='rocm') def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float, @@ -481,5 +508,8 @@ def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval, if gpu_rnn: mlir.register_lowering( rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda') + if hasattr(gpu_rnn, "miopen_rnn_bwd_lowering"): + mlir.register_lowering( + rnn_bwd_p, gpu_rnn.miopen_rnn_bwd_lowering, platform='rocm') lstm.defvjp(lstm_fwd, lstm_bwd) diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index 27fb8f9c4a06..80e00c27a2a7 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -30,7 +30,7 @@ namespace jax { namespace JAX_GPU_NAMESPACE { std::string ErrorString(gpudnnStatus_t status) { - return cudnnGetErrorString(status); + return gpudnnGetErrorString(status); } template @@ -80,63 +80,88 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; - cudnnRNNDescriptor_t rnn_desc; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc))); + gpudnnRNNDescriptor_t rnn_desc; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnCreateRNNDescriptor(&rnn_desc))); - cudnnDropoutDescriptor_t dropout_desc; + gpudnnDropoutDescriptor_t dropout_desc; JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc))); + JAX_AS_STATUS(gpudnnCreateDropoutDescriptor(&dropout_desc))); size_t state_size; JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor( + JAX_AS_STATUS(gpudnnDropoutGetStatesSize(handle.get(), &state_size))); + +#ifdef JAX_GPU_HIP + void* dropout_states_dev = nullptr; + // Allocate minimal memory for dropout states (can be very small since it's not used) + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMalloc(&dropout_states_dev, state_size))); + if (!dropout_states_dev) { + return absl::InternalError("Failed to allocate minimal GPU memory for dropout states."); + } + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( + dropout_desc, handle.get(), dropout, dropout_states_dev, state_size, 123, false, false, + MIOPEN_RNG_PSEUDO_XORWOW))); +#else // JAX_GPU_CUDA + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( dropout_desc, handle.get(), dropout, nullptr, state_size, 123))); +#endif // JAX_GPU_HIP // TODO(zhangqiaorjc): Handle other kinds of RNN. - cudnnRNNMode_t cell_mode = CUDNN_LSTM; - cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS; + gpudnnRNNMode_t cell_mode = GPUDNN_LSTM; + gpudnnRNNBiasMode_t bias_mode = GPUDNN_RNN_DOUBLE_BIAS; int num_directions = 1; - cudnnDirectionMode_t dir_mode = CUDNN_UNIDIRECTIONAL; + gpudnnDirectionMode_t dir_mode = GPUDNN_UNIDIRECTIONAL; if (bidirectional) { - dir_mode = CUDNN_BIDIRECTIONAL; + dir_mode = GPUDNN_BIDIRECTIONAL; num_directions = 2; } - cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; - cudnnDataType_t data_type = CUDNN_DATA_FLOAT; - cudnnDataType_t math_prec = CUDNN_DATA_FLOAT; - cudnnMathType_t math_type = cudnn_allow_tf32? CUDNN_DEFAULT_MATH: CUDNN_FMA_MATH; + gpudnnRNNInputMode_t input_mode = GPUDNN_LINEAR_INPUT; + gpudnnDataType_t data_type = GPUDNN_DATA_FLOAT; + +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( + rnn_desc, hidden_size, num_layers, dropout_desc, input_mode, dir_mode, + cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type))); +#else // JAX_GPU_CUDA + gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT; + gpudnnMathType_t math_type = cudnn_allow_tf32? GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH; int32_t proj_size = hidden_size; - uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8( - rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode, + uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( + rnn_desc, GPUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode, input_mode, data_type, math_prec, math_type, input_size, hidden_size, proj_size, num_layers, dropout_desc, aux_flags))); +#endif // JAX_GPU_HIP - cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_TRAINING; - cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED; + gpudnnForwardMode_t fwdMode = GPUDNN_FWD_MODE_TRAINING; + gpudnnRNNDataLayout_t layout = GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED; float padding = 0.0f; std::vector seq_length_vector(batch_size, max_seq_length); int32_t* seq_length_array = &seq_length_vector[0]; - cudnnRNNDataDescriptor_t input_data_desc; + gpudnnRNNDataDescriptor_t input_data_desc; JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor( + JAX_AS_STATUS(gpudnnCreateRNNDataDescriptor(&input_data_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDataDescriptor( input_data_desc, data_type, layout, max_seq_length, batch_size, input_size, seq_length_array, &padding))); size_t workSpaceSize; size_t reserveSpaceSize; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnGetRNNTempSpaceSizes( +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnGetRNNTempSpaceSizes( + handle.get(), rnn_desc, input_data_desc, fwdMode, &workSpaceSize, + &reserveSpaceSize))); +#else // JAX_GPU_CUDA + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnGetRNNTempSpaceSizes( handle.get(), rnn_desc, fwdMode, input_data_desc, &workSpaceSize, &reserveSpaceSize))); - +#endif // JAX_GPU_HIP JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnDestroyDropoutDescriptor(dropout_desc))); + JAX_AS_STATUS(gpudnnDestroyDropoutDescriptor(dropout_desc))); JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(input_data_desc))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc))); + JAX_AS_STATUS(gpudnnDestroyRNNDataDescriptor(input_data_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); // Round up to nearest multiples of 4 so we can return them as f32 arrays. workSpaceSize += (workSpaceSize % 4); @@ -162,41 +187,61 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; - cudnnRNNDescriptor_t rnn_desc; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc))); + gpudnnRNNDescriptor_t rnn_desc; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnCreateRNNDescriptor(&rnn_desc))); - cudnnDropoutDescriptor_t dropout_desc; + gpudnnDropoutDescriptor_t dropout_desc; JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc))); + JAX_AS_STATUS(gpudnnCreateDropoutDescriptor(&dropout_desc))); size_t state_size; JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor( + JAX_AS_STATUS(gpudnnDropoutGetStatesSize(handle.get(), &state_size))); + + void* dropout_states_dev = nullptr; + // Allocate minimal memory for dropout states (can be very small since it's not used). + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMalloc(&dropout_states_dev, state_size))); + if (!dropout_states_dev) { + return absl::InternalError("Failed to allocate minimal GPU memory for dropout states."); + } + +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( + dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size, 123, false, false, + MIOPEN_RNG_PSEUDO_XORWOW))); +#else // JAX_GPU_CUDA + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123))); +#endif // JAX_GPU_HIP // TODO(zhangqiaorjc): Handle other kinds of RNN. - cudnnRNNMode_t cell_mode = CUDNN_LSTM; - cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS; + gpudnnRNNMode_t cell_mode = GPUDNN_LSTM; + gpudnnRNNBiasMode_t bias_mode = GPUDNN_RNN_DOUBLE_BIAS; int num_directions = 1; - cudnnDirectionMode_t dir_mode = CUDNN_UNIDIRECTIONAL; + gpudnnDirectionMode_t dir_mode = GPUDNN_UNIDIRECTIONAL; if (d.bidirectional) { - dir_mode = CUDNN_BIDIRECTIONAL; + dir_mode = GPUDNN_BIDIRECTIONAL; num_directions = 2; } - cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; - cudnnDataType_t data_type = CUDNN_DATA_FLOAT; - cudnnDataType_t math_prec = CUDNN_DATA_FLOAT; - cudnnMathType_t math_type = d.cudnn_allow_tf32? CUDNN_DEFAULT_MATH: CUDNN_FMA_MATH; + gpudnnRNNInputMode_t input_mode = GPUDNN_LINEAR_INPUT; + gpudnnDataType_t data_type = GPUDNN_DATA_FLOAT; + +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( + rnn_desc, d.hidden_size, d.num_layers, dropout_desc, input_mode, dir_mode, + cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type))); +#else // JAX_GPU_CUDA + gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT; + gpudnnMathType_t math_type = d.cudnn_allow_tf32? GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH; int32_t proj_size = d.hidden_size; - uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED; - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8( - rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode, + uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( + rnn_desc, GPUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode, input_mode, data_type, math_prec, math_type, d.input_size, d.hidden_size, proj_size, d.num_layers, dropout_desc, aux_flags))); +#endif // JAX_GPU_HIP - cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_TRAINING; - cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED; + gpudnnForwardMode_t fwdMode = GPUDNN_FWD_MODE_TRAINING; + gpudnnRNNDataLayout_t layout = GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED; float padding = 0.0f; // TODO(zhangqiaorjc): Avoid this cudaMemcpy if possible. @@ -209,17 +254,17 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers, gpuMemcpyDeviceToHost, stream))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); - cudnnRNNDataDescriptor_t input_data_desc; + gpudnnRNNDataDescriptor_t input_data_desc; JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor( + JAX_AS_STATUS(gpudnnCreateRNNDataDescriptor(&input_data_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDataDescriptor( input_data_desc, data_type, layout, d.max_seq_length, d.batch_size, d.input_size, seq_length_array, &padding))); - cudnnRNNDataDescriptor_t output_data_desc; + gpudnnRNNDataDescriptor_t output_data_desc; JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&output_data_desc))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor( + JAX_AS_STATUS(gpudnnCreateRNNDataDescriptor(&output_data_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDataDescriptor( output_data_desc, data_type, layout, d.max_seq_length, d.batch_size, d.hidden_size * num_directions, seq_length_array, &padding))); @@ -232,19 +277,31 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers, strides[0] = dims[1] * dims[2]; strides[1] = dims[2]; strides[2] = 1; - cudnnTensorDescriptor_t h_desc; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&h_desc))); + gpudnnTensorDescriptor_t h_desc; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnCreateTensorDescriptor(&h_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cudnnSetTensorNdDescriptor(h_desc, data_type, 3, dims, strides))); + gpudnnSetTensorNdDescriptor(h_desc, data_type, 3, dims, strides))); - cudnnTensorDescriptor_t c_desc; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&c_desc))); + gpudnnTensorDescriptor_t c_desc; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnCreateTensorDescriptor(&c_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cudnnSetTensorNdDescriptor(c_desc, data_type, 3, dims, strides))); + gpudnnSetTensorNdDescriptor(c_desc, data_type, 3, dims, strides))); size_t weight_space_size; +#ifdef JAX_GPU_HIP +miopenTensorDescriptor_t input_tensor_desc; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc))); + int dimsA[2] = {d.batch_size, d.input_size}; + int stridesA[2] = {dimsA[1], 1}; // Row-major order, similar to GPUDNN + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenSetTensorDescriptor( + input_tensor_desc, data_type, 2, dimsA, stridesA))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + gpudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, input_tensor_desc, + &weight_space_size, data_type))); +#else // JAX_GPU_CUDA JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size))); + gpudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size))); +#endif // JAX_GPU_HIP auto input_buf = buffers[0]; auto h_0_buf = buffers[1]; @@ -255,22 +312,30 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers, auto c_n_buf = buffers[7]; auto workspace_buf = buffers[8]; auto reserve_space_buf = buffers[9]; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNForward( + +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNForward( + handle.get(), rnn_desc, fwdMode, input_data_desc, input_buf, + h_desc, h_0_buf, h_n_buf, c_desc, c_0_buf, c_n_buf, + output_data_desc, output_buf, weights_buf, weight_space_size, + workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size))); +#else // JAX_GPU_CUDA + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNForward( handle.get(), rnn_desc, fwdMode, (const int32_t*)seq_lengths_buf, input_data_desc, input_buf, output_data_desc, output_buf, h_desc, h_0_buf, h_n_buf, c_desc, c_0_buf, c_n_buf, weight_space_size, weights_buf, - d.workspace_size, workspace_buf, d.reserve_space_size, - reserve_space_buf))); + workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size))); +#endif // JAX_GPU_HIP - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(h_desc))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(c_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(h_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(c_desc))); JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnDestroyDropoutDescriptor(dropout_desc))); + JAX_AS_STATUS(gpudnnDestroyRNNDataDescriptor(input_data_desc))); JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(input_data_desc))); + JAX_AS_STATUS(gpudnnDestroyRNNDataDescriptor(output_data_desc))); JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(output_data_desc))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc))); + JAX_AS_STATUS(gpudnnDestroyDropoutDescriptor(dropout_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); return absl::OkStatus(); } @@ -284,40 +349,60 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; - cudnnRNNDescriptor_t rnn_desc; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc))); + gpudnnRNNDescriptor_t rnn_desc; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnCreateRNNDescriptor(&rnn_desc))); - cudnnDropoutDescriptor_t dropout_desc; + gpudnnDropoutDescriptor_t dropout_desc; JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc))); + JAX_AS_STATUS(gpudnnCreateDropoutDescriptor(&dropout_desc))); size_t state_size; JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor( + JAX_AS_STATUS(gpudnnDropoutGetStatesSize(handle.get(), &state_size))); + + void* dropout_states_dev = nullptr; + // Allocate minimal memory for dropout states (can be very small since it's not used) + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMalloc(&dropout_states_dev, state_size))); + if (!dropout_states_dev) { + return absl::InternalError("Failed to allocate minimal GPU memory for dropout states."); + } + +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( + dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size, 123, false, false, + MIOPEN_RNG_PSEUDO_XORWOW))); +#else // JAX_GPU_CUDA + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor( dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123))); +#endif // JAX_GPU_HIP // TODO(zhangqiaorjc): Handle other kinds of RNN. - cudnnRNNMode_t cell_mode = CUDNN_LSTM; - cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS; + gpudnnRNNMode_t cell_mode = GPUDNN_LSTM; + gpudnnRNNBiasMode_t bias_mode = GPUDNN_RNN_DOUBLE_BIAS; int num_directions = 1; - cudnnDirectionMode_t dir_mode = CUDNN_UNIDIRECTIONAL; + gpudnnDirectionMode_t dir_mode = GPUDNN_UNIDIRECTIONAL; if (d.bidirectional) { - dir_mode = CUDNN_BIDIRECTIONAL; + dir_mode = GPUDNN_BIDIRECTIONAL; num_directions = 2; } - cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; - cudnnDataType_t data_type = CUDNN_DATA_FLOAT; - cudnnDataType_t math_prec = CUDNN_DATA_FLOAT; - cudnnMathType_t math_type = d.cudnn_allow_tf32? CUDNN_DEFAULT_MATH: CUDNN_FMA_MATH; + gpudnnRNNInputMode_t input_mode = GPUDNN_LINEAR_INPUT; + gpudnnDataType_t data_type = GPUDNN_DATA_FLOAT; + +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( + rnn_desc, d.hidden_size, d.num_layers, dropout_desc, input_mode, dir_mode, + cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type))); +#else // JAX_GPU_CUDA + gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT; + gpudnnMathType_t math_type = d.cudnn_allow_tf32? GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH; int32_t proj_size = d.hidden_size; - uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED; - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8( - rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode, + uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor( + rnn_desc, GPUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode, input_mode, data_type, math_prec, math_type, d.input_size, d.hidden_size, proj_size, d.num_layers, dropout_desc, aux_flags))); +#endif // JAX_GPU_HIP - cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED; + gpudnnRNNDataLayout_t layout = GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED; float padding = 0.0f; auto seq_lengths_buf = buffers[10]; @@ -329,17 +414,17 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, gpuMemcpyDeviceToHost, stream))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); - cudnnRNNDataDescriptor_t input_data_desc; + gpudnnRNNDataDescriptor_t input_data_desc; JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor( + JAX_AS_STATUS(gpudnnCreateRNNDataDescriptor(&input_data_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDataDescriptor( input_data_desc, data_type, layout, d.max_seq_length, d.batch_size, d.input_size, seq_length_array, &padding))); - cudnnRNNDataDescriptor_t output_data_desc; + gpudnnRNNDataDescriptor_t output_data_desc; JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&output_data_desc))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor( + JAX_AS_STATUS(gpudnnCreateRNNDataDescriptor(&output_data_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDataDescriptor( output_data_desc, data_type, layout, d.max_seq_length, d.batch_size, d.hidden_size * num_directions, seq_length_array, &padding))); @@ -352,19 +437,31 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, strides[0] = dims[1] * dims[2]; strides[1] = dims[2]; strides[2] = 1; - cudnnTensorDescriptor_t h_desc; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&h_desc))); + gpudnnTensorDescriptor_t h_desc; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnCreateTensorDescriptor(&h_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cudnnSetTensorNdDescriptor(h_desc, data_type, 3, dims, strides))); + gpudnnSetTensorNdDescriptor(h_desc, data_type, 3, dims, strides))); - cudnnTensorDescriptor_t c_desc; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&c_desc))); + gpudnnTensorDescriptor_t c_desc; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnCreateTensorDescriptor(&c_desc))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cudnnSetTensorNdDescriptor(c_desc, data_type, 3, dims, strides))); + gpudnnSetTensorNdDescriptor(c_desc, data_type, 3, dims, strides))); size_t weight_space_size; +#ifdef JAX_GPU_HIP + miopenTensorDescriptor_t input_tensor_desc; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc))); + int input_dims[2] = {d.batch_size, d.input_size}; + int input_strides[2] = {input_dims[1], 1}; // row-major order + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenSetTensorDescriptor( + input_tensor_desc, data_type, 2, input_dims, input_strides))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + gpudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, input_tensor_desc, + &weight_space_size, data_type))); +#else // JAX_GPU_CUDA JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size))); + gpudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size))); +#endif // JAX_GPU_HIP auto dy_buf = buffers[0]; auto dh_n_buf = buffers[1]; @@ -384,29 +481,43 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, // auto dw_buf = buffers[14]; auto workspace_buf = buffers[15]; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNBackwardData_v8( +#ifdef JAX_GPU_HIP + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardData( + handle.get(), rnn_desc, output_data_desc, y_buf, dy_buf, + h_desc, h_0_buf, dh_n_buf, dh_0_buf, + c_desc, c_0_buf, dc_n_buf, dc_0_buf, + input_data_desc, dx_buf, w_buf, weight_space_size, + workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size))); + + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardWeights( + handle.get(), rnn_desc, input_data_desc, x_buf, h_desc, h_0_buf, + output_data_desc, y_buf, zeroed_dw_buf, weight_space_size, + workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size))); +#else // JAX_GPU_CUDA + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardData( handle.get(), rnn_desc, (const int32_t*)seq_lengths_buf, output_data_desc, y_buf, dy_buf, input_data_desc, dx_buf, h_desc, h_0_buf, dh_n_buf, dh_0_buf, c_desc, c_0_buf, dc_n_buf, dc_0_buf, weight_space_size, w_buf, d.workspace_size, workspace_buf, d.reserve_space_size, reserve_space_buf))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNBackwardWeights_v8( - handle.get(), rnn_desc, CUDNN_WGRAD_MODE_ADD, + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardWeights( + handle.get(), rnn_desc, GPUDNN_WGRAD_MODE_ADD, (const int32_t*)seq_lengths_buf, input_data_desc, x_buf, h_desc, h_0_buf, output_data_desc, y_buf, weight_space_size, zeroed_dw_buf, d.workspace_size, workspace_buf, d.reserve_space_size, reserve_space_buf))); +#endif // JAX_GPU_HIP - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(h_desc))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(c_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(h_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyTensorDescriptor(c_desc))); JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnDestroyDropoutDescriptor(dropout_desc))); + JAX_AS_STATUS(gpudnnDestroyRNNDataDescriptor(input_data_desc))); JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(input_data_desc))); + JAX_AS_STATUS(gpudnnDestroyRNNDataDescriptor(output_data_desc))); JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(output_data_desc))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc))); + JAX_AS_STATUS(gpudnnDestroyDropoutDescriptor(dropout_desc))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnDestroyRNNDescriptor(rnn_desc))); return absl::OkStatus(); } diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 648580f08a92..a74e1b79a75d 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -72,6 +72,18 @@ typedef CUevent gpuEvent_t; typedef CUfunction gpuFunction_t; typedef cudnnHandle_t gpudnnHandle_t; typedef cudnnStatus_t gpudnnStatus_t; +typedef cudnnRNNDescriptor_t gpudnnRNNDescriptor_t; +typedef cudnnDropoutDescriptor_t gpudnnDropoutDescriptor_t; +typedef cudnnTensorDescriptor_t gpudnnTensorDescriptor_t; +typedef cudnnRNNDataDescriptor_t gpudnnRNNDataDescriptor_t; +typedef cudnnRNNDataLayout_t gpudnnRNNDataLayout_t; +typedef cudnnMathType_t gpudnnMathType_t; +typedef cudnnDataType_t gpudnnDataType_t; +typedef cudnnRNNInputMode_t gpudnnRNNInputMode_t; +typedef cudnnDirectionMode_t gpudnnDirectionMode_t; +typedef cudnnRNNBiasMode_t gpudnnRNNBiasMode_t; +typedef cudnnRNNMode_t gpudnnRNNMode_t; +typedef cudnnForwardMode_t gpudnnForwardMode_t; typedef CUmodule gpuModule_t; typedef cusolverDnHandle_t gpusolverDnHandle_t; typedef cusolverStatus_t gpusolverStatus_t; @@ -114,9 +126,41 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUBLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS #define gpudnnCreate cudnnCreate +#define gpudnnGetErrorString cudnnGetErrorString +#define gpudnnCreateRNNDescriptor cudnnCreateRNNDescriptor #define gpudnnSetStream cudnnSetStream +#define gpudnnDropoutGetStatesSize cudnnDropoutGetStatesSize +#define gpudnnSetDropoutDescriptor cudnnSetDropoutDescriptor +#define gpudnnDestroyRNNDescriptor cudnnDestroyRNNDescriptor +#define gpudnnDestroyRNNDataDescriptor cudnnDestroyRNNDataDescriptor +#define gpudnnDestroyTensorDescriptor cudnnDestroyTensorDescriptor +#define gpudnnDestroyDropoutDescriptor cudnnDestroyDropoutDescriptor +#define gpudnnRNNBackwardWeights cudnnRNNBackwardWeights_v8 +#define gpudnnRNNBackwardData cudnnRNNBackwardData_v8 +#define gpudnnGetRNNWeightSpaceSize cudnnGetRNNWeightSpaceSize +#define gpudnnCreateTensorDescriptor cudnnCreateTensorDescriptor +#define gpudnnSetTensorNdDescriptor cudnnSetTensorNdDescriptor +#define gpudnnCreateRNNDataDescriptor cudnnCreateRNNDataDescriptor +#define gpudnnSetRNNDataDescriptor cudnnSetRNNDataDescriptor +#define gpudnnSetRNNDescriptor cudnnSetRNNDescriptor_v8 +#define gpudnnCreateDropoutDescriptor cudnnCreateDropoutDescriptor +#define gpudnnGetRNNTempSpaceSizes cudnnGetRNNTempSpaceSizes +#define gpudnnRNNForward cudnnRNNForward #define GPUDNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS +#define GPUDNN_WGRAD_MODE_ADD CUDNN_WGRAD_MODE_ADD +#define GPUDNN_RNN_ALGO_STANDARD CUDNN_RNN_ALGO_STANDARD +#define GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED +#define GPUDNN_RNN_PADDED_IO_ENABLED CUDNN_RNN_PADDED_IO_ENABLED +#define GPUDNN_DEFAULT_MATH CUDNN_DEFAULT_MATH +#define GPUDNN_FMA_MATH CUDNN_FMA_MATH +#define GPUDNN_DATA_FLOAT CUDNN_DATA_FLOAT +#define GPUDNN_LINEAR_INPUT CUDNN_LINEAR_INPUT +#define GPUDNN_FWD_MODE_TRAINING CUDNN_FWD_MODE_TRAINING +#define GPUDNN_UNIDIRECTIONAL CUDNN_UNIDIRECTIONAL +#define GPUDNN_RNN_DOUBLE_BIAS CUDNN_RNN_DOUBLE_BIAS +#define GPUDNN_LSTM CUDNN_LSTM +#define GPUDNN_BIDIRECTIONAL CUDNN_BIDIRECTIONAL #define gpusolverDnCreate cusolverDnCreate #define gpusolverDnSetStream cusolverDnSetStream @@ -364,6 +408,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #include "rocm/include/hipblas/hipblas.h" #include "rocm/include/hipsolver/hipsolver.h" #include "rocm/include/hipsparse/hipsparse.h" +#include "rocm/include/miopen/miopen.h" // IWYU pragma: end_exports #define JAX_GPU_NAMESPACE hip @@ -372,6 +417,9 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #define JAX_GPU_HAVE_SPARSE 1 #define JAX_GPU_HAVE_64_BIT 0 #define JAX_GPU_HAVE_FP8 0 +// TODO(Ruturaj4): Currently equivalent API does exist in +// MIOpen lib. Remove when MIOpen support is complete. +#define MIOPEN_STATUS_SUCCESS 0 typedef hipFloatComplex gpuComplex; typedef hipDoubleComplex gpuDoubleComplex; @@ -394,6 +442,19 @@ typedef hipStream_t gpuStream_t; typedef hipError_t gpuError_t; typedef hipEvent_t gpuEvent_t; typedef hipFunction_t gpuFunction_t; +typedef miopenHandle_t gpudnnHandle_t; +typedef miopenStatus_t gpudnnStatus_t; +typedef miopenRNNDescriptor_t gpudnnRNNDescriptor_t; +typedef miopenDropoutDescriptor_t gpudnnDropoutDescriptor_t; +typedef miopenTensorDescriptor_t gpudnnTensorDescriptor_t; +typedef miopenSeqTensorDescriptor_t gpudnnRNNDataDescriptor_t; +typedef miopenRNNBaseLayout_t gpudnnRNNDataLayout_t; +typedef miopenDataType_t gpudnnDataType_t; +typedef miopenRNNInputMode_t gpudnnRNNInputMode_t; +typedef miopenRNNDirectionMode_t gpudnnDirectionMode_t; +typedef miopenRNNBiasMode_t gpudnnRNNBiasMode_t; +typedef miopenRNNMode_t gpudnnRNNMode_t; +typedef miopenRNNFWDMode_t gpudnnForwardMode_t; typedef hipModule_t gpuModule_t; typedef void gpuSyevjInfo; typedef hipsolverSyevjInfo_t gpuSyevjInfo_t; @@ -432,6 +493,39 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define gpudnnCreate miopenCreate +#define gpudnnGetErrorString miopenGetErrorString +#define gpudnnSetStream miopenSetStream +#define gpudnnCreateRNNDescriptor miopenCreateRNNDescriptor +#define gpudnnDropoutGetStatesSize miopenDropoutGetStatesSize +#define gpudnnSetDropoutDescriptor miopenSetDropoutDescriptor +#define gpudnnDestroyRNNDescriptor miopenDestroyRNNDescriptor +#define gpudnnDestroyRNNDataDescriptor miopenDestroySeqTensorDescriptor +#define gpudnnDestroyTensorDescriptor miopenDestroyTensorDescriptor +#define gpudnnDestroyDropoutDescriptor miopenDestroyDropoutDescriptor +#define gpudnnRNNBackwardWeights miopenRNNBackwardWeightsSeqTensor +#define gpudnnCreateRNNDataDescriptor miopenCreateSeqTensorDescriptor +#define gpudnnRNNBackwardData miopenRNNBackwardSeqData +#define gpudnnCreateTensorDescriptor miopenCreateTensorDescriptor +#define gpudnnSetTensorNdDescriptor miopenSetTensorDescriptor +#define gpudnnSetRNNDataDescriptor miopenSetRNNDataSeqTensorDescriptor +#define gpudnnSetRNNDescriptor miopenSetRNNDescriptor_V2 +#define gpudnnCreateDropoutDescriptor miopenCreateDropoutDescriptor +#define gpudnnGetRNNTempSpaceSizes miopenGetRNNTempSpaceSizes +#define gpudnnRNNForward miopenRNNForward +#define gpudnnGetRNNWeightSpaceSize miopenGetRNNParamsSize + +#define GPUDNN_STATUS_SUCCESS MIOPEN_STATUS_SUCCESS +#define GPUDNN_RNN_ALGO_STANDARD miopenRNNdefault +#define GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED miopenRNNDataSeqMajorPadded +#define GPUDNN_DATA_FLOAT miopenFloat +#define GPUDNN_LINEAR_INPUT miopenRNNlinear +#define GPUDNN_FWD_MODE_TRAINING miopenRNNTraining +#define GPUDNN_UNIDIRECTIONAL miopenRNNunidirection +#define GPUDNN_RNN_DOUBLE_BIAS miopenRNNwithBias +#define GPUDNN_LSTM miopenLSTM +#define GPUDNN_BIDIRECTIONAL miopenRNNbidirection + #define gpusolverDnCreate hipsolverCreate #define gpusolverDnSetStream hipsolverSetStream #define gpusolverDnCreateSyevjInfo hipsolverCreateSyevjInfo diff --git a/jaxlib/gpu_rnn.py b/jaxlib/gpu_rnn.py index 0fc3dc350967..b0aa877df1eb 100644 --- a/jaxlib/gpu_rnn.py +++ b/jaxlib/gpu_rnn.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import importlib import jaxlib.mlir.ir as ir @@ -24,19 +25,33 @@ for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: try: - _rnn = importlib.import_module(f"{cuda_module_name}._rnn", package="jaxlib") + _cuda_rnn = importlib.import_module(f"{cuda_module_name}._rnn", package="jaxlib") except ImportError: - _rnn = None + _cuda_rnn = None else: break -if _rnn: - for _name, _value in _rnn.registrations().items(): +if _cuda_rnn: + for _name, _value in _cuda_rnn.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform='CUDA') - compute_rnn_workspace_reserve_space_sizes = _rnn.compute_rnn_workspace_reserve_space_sizes + compute_rnn_workspace_reserve_space_sizes = _cuda_rnn.compute_rnn_workspace_reserve_space_sizes -def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *, +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hip_rnn = importlib.import_module(f"{rocm_module_name}._rnn", package="jaxlib") + except ImportError: + _hip_rnn = None + else: + break + +if _hip_rnn: + for _name, _value in _hip_rnn.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform='ROCM') + compute_rnn_workspace_reserve_space_sizes = _hip_rnn.compute_rnn_workspace_reserve_space_sizes + + +def _rnn_fwd_lowering(_rnn, platform, ctx, input, h_0, c_0, weights, seq_lengths, *, input_size: int, hidden_size: int, num_layers: int, dropout: bool, bidirectional: bool, cudnn_allow_tf32: bool): @@ -75,11 +90,10 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *, reserve_space_shape[0]) i32_type = ir.IntegerType.get_signless(32) - out = hlo.CustomCallOp( [output_type, h_0.type, c_0.type, workspace_type, reserve_space_type], [input, h_0, c_0, weights, seq_lengths], - call_target_name=ir.StringAttr.get('cudnn_rnn'), + call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn"), has_side_effect=ir.BoolAttr.get(False), backend_config=ir.StringAttr.get(opaque), api_version=ir.IntegerAttr.get(i32_type, 2), @@ -87,6 +101,9 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *, ) return out.results[:-2] + out.results[-1:] # drop workspace output +cudnn_rnn_fwd_lowering = partial(_rnn_fwd_lowering, _cuda_rnn, "cu") +miopen_rnn_fwd_lowering = partial(_rnn_fwd_lowering, _hip_rnn, "hip") + def _hlo_zeros_f32(shape): return hlo.constant( @@ -94,7 +111,7 @@ def _hlo_zeros_f32(shape): np.zeros(shape, dtype=np.float32), type=ir.F32Type.get())) -def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, +def _rnn_bwd_lowering(_rnn, platform, ctx, dy, dhn, dcn, x, h0, c0, w, y, reserve_space, seq_lengths, *, input_size: int, hidden_size: int, num_layers: int, dropout: bool, bidirectional: bool, cudnn_allow_tf32: bool): @@ -123,7 +140,7 @@ def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, dy, dhn, dcn, x, h0, c0, w, y, reserve_space, zeroed_dw, seq_lengths ], - call_target_name=ir.StringAttr.get('cudnn_rnn_bwd'), + call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn_bwd"), has_side_effect=ir.BoolAttr.get(False), backend_config=ir.StringAttr.get(opaque), api_version=ir.IntegerAttr.get(i32_type, 2), @@ -135,3 +152,6 @@ def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, operand_tuple_indices=[]) ])) return out.results[:-1] # drop workspace output + +cudnn_rnn_bwd_lowering = partial(_rnn_bwd_lowering, _cuda_rnn, "cu") +miopen_rnn_bwd_lowering = partial(_rnn_bwd_lowering, _hip_rnn, "hip") diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 26c473201cca..06ed7b291e01 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -135,6 +135,44 @@ nanobind_extension( ], ) +cc_library( + name = "miopen_rnn_kernels", + srcs = ["//jaxlib/gpu:rnn_kernels.cc"], + hdrs = ["//jaxlib/gpu:rnn_kernels.h"], + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_vendor", + "//jaxlib:handle_pool", + "//jaxlib:kernel_helpers", + "@xla//xla/service:custom_call_status", + "@local_config_rocm//rocm:miopen", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@local_config_rocm//rocm:rocm_headers", + ], +) + +nanobind_extension( + name = "_rnn", + srcs = ["//jaxlib/gpu:rnn.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_rnn", + deps = [ + ":hip_vendor", + ":miopen_rnn_kernels", + "//jaxlib:absl_status_casters", + "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:str_format", + "@nanobind", + ], +) + cc_library( name = "hip_solver_handle_pool", srcs = ["//jaxlib/gpu:solver_handle_pool.cc"], @@ -501,6 +539,7 @@ py_library( ":_hybrid", ":_linalg", ":_prng", + ":_rnn", ":_solver", ":_sparse", ":_triton", diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 36c1b4d2cbfc..65412f0365dc 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -140,12 +140,13 @@ def prepare_wheel_rocm( copy_runfiles( dst_dir=plugin_dir, src_files=[ - f"__main__/jaxlib/rocm/_solver.{pyext}", f"__main__/jaxlib/rocm/_blas.{pyext}", f"__main__/jaxlib/rocm/_linalg.{pyext}", f"__main__/jaxlib/rocm/_prng.{pyext}", + f"__main__/jaxlib/rocm/_solver.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", f"__main__/jaxlib/rocm/_hybrid.{pyext}", + f"__main__/jaxlib/rocm/_rnn.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", f"__main__/jaxlib/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py index d886a84f914b..34b7500278fe 100644 --- a/tests/experimental_rnn_test.py +++ b/tests/experimental_rnn_test.py @@ -34,14 +34,13 @@ class RnnTest(jtu.JaxTestCase): num_layers=[1, 4], bidirectional=[True, False], ) - @jtu.run_on_devices("cuda") + @jtu.run_on_devices("cuda", "rocm") @jax.default_matmul_precision("float32") def test_lstm(self, batch_size: int, seq_len: int, input_size: int, hidden_size: int, num_layers: int, bidirectional: bool): - # TODO(phawkins): Partially disable this on cudnn version per b/281071013 - if (batch_size == 1 and seq_len == 4 and input_size == 1 and - hidden_size == 6 and num_layers == 4 and bidirectional == False): - self.skipTest("Test requires cudnn >= 8.8") + # TODO(ruturaj4): Bidirectional doesn't quite work well with rocm. + if bidirectional and jtu.is_device_rocm(): + self.skipTest("Bidirectional mode is not available for ROCm.") num_directions = 2 if bidirectional else 1 seq_length_key, root_key = jax.random.split(jax.random.PRNGKey(0)) @@ -61,6 +60,8 @@ def test_lstm(self, batch_size: int, seq_len: int, input_size: int, weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers, bidirectional) def f(weights, x, h_0, c_0): + if jtu.is_device_rocm(): + weights = rnn.swap_lstm_gates(weights, input_size, hidden_size, num_layers, bidirectional) y, h, c = rnn.lstm( x, h_0,