Skip to content

Commit

Permalink
addressed remaining review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
t4c1 committed Dec 4, 2024
1 parent fcc7d2f commit d5a77e3
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions examples/35_gemm_softmax/gemm_softmax_adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,17 @@ class GemmSoftmaxAdapter
CUTLASS_TRACE_HOST("GemmUniversal::run()");
dim3 const block = GemmKernel::get_block_shape();
dim3 const grid = get_grid_shape(params);
dim3 const block_finalize = syclcompat::dim3(NumThreadsPerWarp,
std::min(MaxNumThreadsPerBlock / NumThreadsPerWarp,
params.softmax_params.args.M),
1);
dim3 const grid_finalize = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.M, block_finalize.x),
params.softmax_params.args.batch_count,
1);

// configure smem size and carveout
int smem_size = GemmKernel::SharedStorageSize;
int smem_size_finalize = SoftmaxFinalizeKernel::SharedStorageSize;

Status launch_result{ Status::kSuccess };
// Use extended launch API only for mainloops that use it
Expand All @@ -367,7 +375,9 @@ class GemmSoftmaxAdapter
dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{}));
void* kernel_params[] = {&params};
dim3 cluster_finalize(1,1,1);
void* kernel_params[] = {&params.gemm_params};
void* kernel_params_finalize[] = {&params.softmax_params};

if constexpr (kEnableCudaHostAdapter) {
//
Expand All @@ -388,6 +398,13 @@ class GemmSoftmaxAdapter
stream,
kernel_params,
0);
launch_result = cuda_adapter->launch(grid_finalize,
cluster_finalize,
block_finalize,
smem_size_finalize,
stream,
kernel_params_finalize,
0);
}
else {
return Status::kErrorInternal;
Expand All @@ -396,13 +413,17 @@ class GemmSoftmaxAdapter
else {
CUTLASS_ASSERT(cuda_adapter == nullptr);
void const* kernel = (void const*) device_kernel<GemmKernel>;
void const* kernel_finalize = (void const*) device_kernel<SoftmaxFinalizeKernel>;
if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) {
if (is_static_1x1x1 && not launch_with_pdl) {
device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params.gemm_params);
device_kernel<SoftmaxFinalizeKernel><<<grid_finalize, block_finalize, smem_size_finalize, stream>>>(params.softmax_params);
}
else {
launch_result = ClusterLauncher::launch(
grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
launch_result = ClusterLauncher::launch(
grid_finalize, cluster_finalize, block_finalize, smem_size_finalize, stream, kernel_finalize, kernel_params, launch_with_pdl);
}
}
}
Expand All @@ -414,10 +435,14 @@ class GemmSoftmaxAdapter
CUTLASS_ASSERT(cuda_adapter);
if (cuda_adapter) {
void* kernel_params[] = {&params.gemm_params};
void* kernel_params_finalize[] = {&params.softmax_params};

launch_result = cuda_adapter->launch(
grid, block, smem_size, stream, kernel_params, 0
);
launch_result = cuda_adapter->launch(
grid_finalize, block_finalize, smem_size_finalize, stream, kernel_params_finalize, 0
);

}
else {
Expand All @@ -441,19 +466,15 @@ class GemmSoftmaxAdapter
sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(smem_size)}},
params.gemm_params);
#endif
const auto sycl_block2 = syclcompat::dim3(NumThreadsPerWarp,
std::min(MaxNumThreadsPerBlock / NumThreadsPerWarp,
params.softmax_params.args.M),
1);
const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.M, sycl_block2.x),
params.softmax_params.args.batch_count,
1);
const auto sycl_block_finalize = syclcompat::dim3(block_finalize.x, block_finalize.y, block_finalize.z);
const auto sycl_grid_finalize = syclcompat::dim3(grid_finalize.x, grid_finalize.y, grid_finalize.z);
auto event2 = launch<device_kernel<SoftmaxFinalizeKernel>>(launch_policy{
sycl_grid2, sycl_block2, local_mem_size{SoftmaxFinalizeKernel::SharedStorageSize}},
sycl_grid_finalize, sycl_block_finalize, local_mem_size{static_cast<std::size_t>(smem_size_finalize)}},
params.softmax_params);
EventManager::getInstance().addEvent(event2);
#else
device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params.gemm_params);
device_kernel<SoftmaxFinalizeKernel><<<grid_finalize, block_finalize, smem_size_finalize, stream>>>(params.softmax_params);
#endif
}
}
Expand Down

0 comments on commit d5a77e3

Please sign in to comment.