Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -768,8 +768,10 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
if (is_zero_total_L) {
const uint32_t D_start = D_offsets[t] / VEC_WIDTH;
const uint32_t load_D = (D_offsets[t + 1] / VEC_WIDTH) - D_start;
{%- if is_rocm %}
const uint32_t num_warps_per_row = DIV_ROUND_UP(load_D, (kWarpSize/2));
const auto placement = static_cast<PlacementType>(weights_placements[t]);
const auto is_cache = placement == PlacementType::MANAGED_CACHING;
{%- if is_rocm %}
const uint32_t num_warps_per_row = is_cache ? DIV_ROUND_UP(load_D, kWarpSize):DIV_ROUND_UP(load_D, (kWarpSize/2));
{%- else %}
const uint32_t num_warps_per_row = DIV_ROUND_UP(load_D, kWarpSize);
{% endif %}
Expand Down Expand Up @@ -813,15 +815,18 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
load_D = (D_offsets[t + 1] / VEC_WIDTH) - D_start;
}
load_D = shfl_sync(load_D, 0);
{%- if is_rocm %}
const uint32_t num_warps_per_row = DIV_ROUND_UP(load_D, (kWarpSize/2));
const auto placement = static_cast<PlacementType>(weights_placements[t]);
const auto is_cache = placement == PlacementType::MANAGED_CACHING;
{%- if is_rocm %}
const uint32_t num_warps_per_row = is_cache ? DIV_ROUND_UP(load_D, kWarpSize):DIV_ROUND_UP(load_D, (kWarpSize/2));
{%- else %}
const uint32_t num_warps_per_row = DIV_ROUND_UP(load_D, kWarpSize);
{% endif %}
if (table_warp_id >= num_warps_per_row * (is_small_L ? num_warps_for_small_L : B)) {
return;
}


// Compute d (same for all Ls)
const uint32_t load_d = (table_warp_id % num_warps_per_row) * kWarpSize;
// Compute sample ID
Expand Down