diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index 5c9e320d5..3562aeead 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -164,6 +164,159 @@ using namespace fbgemm_gpu; {%- endif %} {%- endmacro %} +{#-/* + Splitted version of load_and_accumulate macro. This code chunk describes + the weights load in forward kernel. Set up the WeightRow and load quantization + parameters. Shortcut store for nobag mode. + + The main difference is in whether the slices are loaded from the embedding + table or cache. + + NOTE: The decision was made to define this code chunk as a Jinja macro + instead of inline C++ function, since the compiler might not be able to + inline the code. + + In-code variables that are defined outside: + emb_t, cache_t, cache_t + idx_j + inner_j + D_emb + lxu_cache_weights + {{ locs_or_addrs_idx }}_j + idx_weight_j + VEC_WIDTH + D + kThreadGroupSize + output_j +*/#} +{%- macro load_weights(from_cache) %} + {%- if from_cache %} + const cache_t* cache_weights; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }}_j)); + {%- else %} + cache_weights = reinterpret_cast( + &lxu_cache_weights[{{ locs_or_addrs_idx }}_j][0]); + {%- endif %} + {%- endif %} + {#-/* Set the weights row */#} + {%- if is_rocm %} + const auto weights_row = rocm::WeightRowAccessorVec2 + {%- else %} + const auto weights_row = WeightRowAccessor + {%- endif %} + < + emb_t, + cache_t, + cache_t, + {%- if from_cache %} + true + {%- else %} + false + {%- endif %} + >( + {%- if from_cache %} + // Pass nullptr to avoid calling &weights[idx_j * D_emb], which loads + // memory into the registers as a side effect + nullptr, + // Load from the cache + cache_weights, + {%- else %} + // Load from the embedding table + &weights[idx_j * D_emb], + // Pass nullptr bc we are loading from the embedding table + nullptr, + {%- endif %} + D); + + {#-/* Set the quantization params */#} + {%- if from_cache %} + // Assume cache is FP16/FP32, which doesn't require quantization params + const auto qparams = make_float2(0.0f, 0.0f); + {%- else %} + // Load the quantization params from the embedding table row if emb_t == uint8_t + const auto qparams = weights_row.load_qparams(); + {%- endif %} + + {%- if not nobag %} + // Iterate over the row in the weights table, in 4-element strides + #pragma unroll kMaxVecsPerThread + for (int32_t i = 0; i < kMaxVecsPerThread; ++i) + { + // Load the slice of the weights + int32_t d = (i * kThreadGroupSize + threadIdx.x) * VEC_WIDTH; + d = (d < D) ? d : 0; + const auto weights_slice = weights_row.load(d, qparams); + vals[inner_j * kMaxVecsPerThread + i] = weights_slice; + } + + {%- else %} + for (int32_t i = 0; i < D; i += kThreadGroupSize * VEC_WIDTH) { + const int32_t d = i + threadIdx.x * VEC_WIDTH; + if (d < D) { + // Since there is no pooling, simply copy the weights to output + const auto weights_slice = weights_row.load(d, qparams); + {%- if is_index_select %} + // output is 1D (because the stride can be irregular) + weights_slice.store(&output[output_offset + output_j * output_stride + d]); + {%- else %} + // output is 2D + weights_slice.store(&output[output_j][d]); + {%- endif %} + } + } + {%- endif %} +{%- endmacro %} + +{#-/* + Splitted version of load_and_accumulate macro. This code chunk + describes the weights accumulate step in the forward kernel. + Accumulate the slices of values from the row. Does nothing for + nobag mode assuming all the work is done in load() macro. + + The main difference is in whether the slices are loaded from the embedding + table or cache. + + NOTE: The decision was made to define this code chunk as a Jinja macro + instead of inline C++ function, since the compiler might not be able to + inline the code. + + In-code variables that are defined outside: + emb_t, cache_t, cache_t + idx_j + inner_j + D_emb + lxu_cache_weights + cache_idx_j + idx_weight_j + VEC_WIDTH + D + kThreadGroupSize + output_j +*/#} +{%- macro accumulate_and_store(from_cache) %} + {%- if not nobag %} + // Iterate over the row in the weights table, in 4-element strides + #pragma unroll kMaxVecsPerThread + for (int32_t i = 0; + i < kMaxVecsPerThread && (i * kThreadGroupSize + threadIdx.x) * VEC_WIDTH < D; + ++i) { + {%- if is_gwd_kernel %} + // Scale weights with global weight decay + vals[inner_j * kMaxVecsPerThread + i].mul_(global_weight_decay_j); + {%- endif %} + {%- if weighted %} + // Accumulate the weights * positional weight + accumulators[i].fma_(vals[inner_j * kMaxVecsPerThread + i], idx_weight_j); + {%- else %} + // Accumulate the weights + accumulators[i].add_(vals[inner_j * kMaxVecsPerThread + i]); + {%- endif %} + } + {%- endif %} +{%- endmacro %} + {#-/* This code chunk contains the implementation body of the kernel, and is defined as a Jinja macro to be copy-pasted directly into the kernel as @@ -203,8 +356,162 @@ using namespace fbgemm_gpu; at::acc_type idx_weight = l < L ? indice_weights[indices_start + l] : 0; {%- endif %} + {%- if is_rocm %} + {%- if not nobag %} + rocm::Vec2T vals[kManualUnrollLength * kMaxVecsPerThread]; + {%- endif %} + // Iterate over kThreadGroupSize indices + for (auto outer_j = 0; outer_j < kThreadGroupSize && l_start + outer_j < L - L % kManualUnrollLength; outer_j += kManualUnrollLength) + { + {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} + // Load index from thread j in the group + [[maybe_unused]] int64_t idx_j_[kManualUnrollLength]; + for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j) + { + idx_j_[inner_j] = SHFL_SYNC(idx, outer_j + inner_j); + } + {%- endif %} + {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %} + // Load cache's index from thread j in the group + [[maybe_unused]] int32_t {{ locs_or_addrs_idx }}_j_[kManualUnrollLength]; + for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j) + { + {{ locs_or_addrs_idx }}_j_[inner_j] = use_lxu_cache ? SHFL_SYNC({{ locs_or_addrs_idx }}, outer_j + inner_j) : 0; + } + {%- endif %} + + {%- if weighted %} + // Load positional weight index from thread j in the group + at::acc_type idx_weight_j_[kManualUnrollLength]; + for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j) + { + idx_weight_j_[inner_j] = SHFL_SYNC(idx_weight, outer_j + inner_j); + } + {%- endif %} + + + for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j) + { + auto j = outer_j + inner_j; + {%- if is_index_select %} + int64_t output_j = L_start + l_start + j; + {%- elif nobag %} + int64_t output_j = indices_start + l_start + j; + {%- endif %} + + {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} + [[maybe_unused]] int64_t idx_j = idx_j_[inner_j]; + {%- endif %} + {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %} + [[maybe_unused]] {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }}_j + = use_lxu_cache ? {{ locs_or_addrs_idx }}_j_[inner_j] : 0; + + {%- endif %} + {%- if weighted %} + at::acc_type idx_weight_j = idx_weight_j_[inner_j]; + {%- endif %} + + + + {#/**************************************************************/#} + {#-/* + This is the main switch that determines how we are to load and + accumulate weights, and is determined by Jinja-time, compile-time, + and run-time variables. + */#} + + {%- if dense %} + {#-/* If it's dense, cache is not supported, so load from the embedding table */#} + {{- load_weights(false) }} + + {%- elif lxu_miss_rate == "cache_conflict_miss_rate::all" %} + {#-/* Else if we know we have a 100% miss rate, then always fetch from the embedding table */#} + {{- load_weights(false) }} + + {%- elif lxu_miss_rate == "cache_conflict_miss_rate::zero" %} + {#-/* Else if we know we have a 0% miss rate, then always fetch from the cache */#} + {{ load_weights(true) }} + {%- else %} + {#-/* Else we defer to run-time selection */#} + if (placement == PlacementType::MANAGED_CACHING + && {{ locs_or_addrs_idx }}_j != kCacheLocationMissing + ) { + {#-/* If the row is available in the cache, fetch from the cache */#} + {{ load_weights(true) }} + } else { + {#-/* Else fetch from the embedding table */#} + {{ load_weights(false) }} + } + + {%- endif %} + {#/**************************************************************/#} + } + {%- if not nobag %} + for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j) + { + auto j = outer_j + inner_j; + + {%- if is_index_select %} + int64_t output_j = L_start + l_start + j; + {%- elif nobag %} + int64_t output_j = indices_start + l_start + j; + {%- endif %} + + {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} + [[maybe_unused]] int64_t idx_j = idx_j_[inner_j]; + {%- endif %} + {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %} + [[maybe_unused]] int32_t {{ locs_or_addrs_idx }}_j = {{ locs_or_addrs_idx }}_j_[inner_j]; + {%- endif %} + {%- if weighted %} + at::acc_type idx_weight_j = idx_weight_j_[inner_j]; + {%- endif %} + {%- if is_gwd_kernel %} + const auto global_weight_decay_j = SHFL_SYNC(global_weight_decay, j); + {%- endif %} + + {#/**************************************************************/#} + {#-/* + This is the main switch that determines how we are to load and + accumulate weights, and is determined by Jinja-time, compile-time, + and run-time variables. + */#} + + {%- if dense %} + {#-/* If it's dense, cache is not supported, so load from the embedding table */#} + {{- accumulate_and_store(false) }} + + {%- elif lxu_miss_rate == "cache_conflict_miss_rate::all" %} + {#-/* Else if we know we have a 100% miss rate, then always fetch from the embedding table */#} + {{- accumulate_and_store(false) }} + + {%- elif lxu_miss_rate == "cache_conflict_miss_rate::zero" %} + {#-/* Else if we know we have a 0% miss rate, then always fetch from the cache */#} + {{ accumulate_and_store(true) }} + {%- else %} + {#-/* Else we defer to run-time selection */#} + if (placement == PlacementType::MANAGED_CACHING + && {{ locs_or_addrs_idx }}_j != kCacheLocationMissing) { + {#-/* If the row is available in the cache, fetch from the cache */#} + {{ accumulate_and_store(true) }} + } else { + {#-/* Else fetch from the embedding table */#} + {{ accumulate_and_store(false) }} + } + + {%- endif %} + {#/**************************************************************/#} + } + {%- endif %} + } + {%- endif %} + + {%- if is_rocm %} + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + kThreadGroupSize > L && l_start + j < L; ++j) { + {%- else %} // Iterate over kThreadGroupSize indices for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { + {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group [[maybe_unused]] int64_t idx_j = SHFL_SYNC(idx, j); @@ -370,6 +677,10 @@ batch_index_select_dim0_codegen_forward_kernel( {%- else %} constexpr int VEC_WIDTH = 4; {%- endif %} + {%- if is_rocm %} + // Unroll factor for ROCm devices + constexpr int kManualUnrollLength = 4; + {%- endif %} // Determine the linearized warp ID, and exit early if needed int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;