diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index b06ac8b2b2..d46199c906 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -77,7 +77,6 @@ benchmark_torch_function, benchmark_vbe, fill_random_scale_bias, - warmup, ) else: from fbgemm_gpu.bench.bench_utils import ( @@ -88,28 +87,12 @@ benchmark_torch_function, benchmark_vbe, fill_random_scale_bias, - warmup, ) logging.basicConfig(level=logging.DEBUG) -def kineto_trace_profiler(p: profile, trace_info: tuple[str, str, str, str]) -> float: - phase, trace_url, tbe_type, kern_name = trace_info - p.export_chrome_trace( - trace_url.format(tbe_type=tbe_type, phase=phase, ospid=os.getpid()) - ) - kernel_time = 0 - for event in p.key_averages(): - # Sum the total time of forward kernel runs - if kern_name in event.key: - kernel_time += event.device_time - assert kernel_time > 0 - print(f"Total CUDA time: {kernel_time:.2f} ") - return kernel_time - - @click.group() def cli() -> None: pass @@ -142,7 +125,6 @@ def cli() -> None: @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--dense", is_flag=True, default=False) @click.option("--output-dtype", type=SparseType, default=SparseType.FP32) -@click.option("--indices-dtype", type=click.Choice(["32", "64"]), default="64") @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) @click.option("--export-trace", is_flag=True, default=False) @@ -162,12 +144,6 @@ def cli() -> None: "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix" ) @click.option("--cache-load-factor", default=0.2) -@click.option( - "--num-requests", - default=-1, - help="Number of input batches to generate. 
If the value is smaller than " - "iters, the benchmark will reuse the input batches", -) def device( # noqa C901 alpha: float, bag_size: int, @@ -190,7 +166,6 @@ def device( # noqa C901 flush_gpu_cache_size_mb: int, dense: bool, output_dtype: SparseType, - indices_dtype: str, requests_data_file: Optional[str], tables: Optional[str], export_trace: bool, @@ -199,13 +174,8 @@ def device( # noqa C901 ssd: bool, ssd_prefix: str, cache_load_factor: float, - num_requests: int, ) -> None: assert not ssd or not dense, "--ssd cannot be used together with --dense" - num_requests = iters if num_requests == -1 else num_requests - indices_dtype_torch: torch.dtype = ( - torch.int32 if int(indices_dtype) == 32 else torch.int64 - ) np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -353,8 +323,9 @@ def device( # noqa C901 logging.info( f"Accessed weights per batch: {B * sum(Ds) * L * param_size_multiplier / 1.0e9: .2f} GB" ) + requests = generate_requests( - num_requests, + iters, B, T, L, @@ -365,8 +336,6 @@ def device( # noqa C901 requests_data_file=requests_data_file, tables=tables, use_cpu=not torch.cuda.is_available(), - index_dtype=torch.long, - offset_dtype=torch.long, ) def _kineto_trace_handler(p: profile, phase: str) -> None: @@ -383,14 +352,13 @@ def context_factory(on_trace_ready: Callable[[profile], None]): time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb.forward( - indices.to(dtype=indices_dtype_torch), - offsets.to(dtype=indices_dtype_torch), + indices.long(), + offsets.long(), per_sample_weights, feature_requires_grad=feature_requires_grad, ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, num_warmups=warmup_runs, - iters=iters, ) logging.info( @@ -416,8 +384,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]): time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb( - indices.to(dtype=indices_dtype_torch), - offsets.to(dtype=indices_dtype_torch), + indices.long(), + offsets.long(), per_sample_weights, feature_requires_grad=feature_requires_grad, ), @@ -425,7 +393,6 @@ def context_factory(on_trace_ready: Callable[[profile], None]): bwd_only=True, grad=grad_output, num_warmups=warmup_runs, - iters=iters, ) logging.info( @@ -603,8 +570,6 @@ def uvm( weighted=weighted, requests_data_file=requests_data_file, tables=tables, - index_dtype=torch.long, - offset_dtype=torch.long, ) requests_gpu = None @@ -620,8 +585,6 @@ def uvm( weighted=False, requests_data_file=requests_data_file, tables=tables, - index_dtype=torch.long, - offset_dtype=torch.long, ) param_size_multiplier = weights_precision.bit_rate() / 8.0 @@ -691,8 +654,8 @@ def run_bench(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor) -> N emb_uvm.local_uvm_cache_stats[4] = 0 if no_conflict_misses else 1 emb_uvm.forward( - indices, - offsets, + indices.long(), + offsets.long(), per_sample_weights, ) @@ -737,8 +700,8 @@ def run_bench(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor) -> N time_per_iter = benchmark_requests( requests_gpu, lambda indices, offsets, per_sample_weights: emb_gpu.forward( - indices, - offsets, + indices.long(), + offsets.long(), per_sample_weights, ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, @@ -758,8 +721,8 @@ def run_bench(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor) -> N time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb_mixed.forward( - indices, - offsets, + indices.long(), + offsets.long(), per_sample_weights, 
), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, @@ -904,8 +867,6 @@ def cache( # noqa C901 weighted=weighted, requests_data_file=requests_data_file, tables=tables, - index_dtype=torch.long, - offset_dtype=torch.long, ) warmup_requests, requests = requests[:iters], requests[iters:] grad_output = torch.randn(B, sum(Ds)).cuda() @@ -913,7 +874,7 @@ def cache( # noqa C901 time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb_nc( - indices, offsets, per_sample_weights + indices.long(), offsets.long(), per_sample_weights ).backward(grad_output), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, num_warmups=warmup_runs, @@ -927,7 +888,7 @@ def cache( # noqa C901 # warm up for req in warmup_requests: indices, offsets = req.unpack_2() - emb.forward(indices, offsets) + emb.forward(indices.long(), offsets.long()) # get cache miss rate (forward and backward) and exchanged cache lines (prefetch) cache_misses = [] exchanged_cache_lines = [] @@ -938,7 +899,7 @@ def cache( # noqa C901 # Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a # function. old_lxu_cache_state = emb.lxu_cache_state.clone() - emb.prefetch(indices, offsets) + emb.prefetch(indices.long(), offsets.long()) exchanged_cache_lines.append( # pyre-fixme[16]: Item `bool` of `bool | Tensor` has no attribute `sum`. (emb.lxu_cache_state != old_lxu_cache_state) @@ -946,7 +907,7 @@ def cache( # noqa C901 .item() ) cache_misses.append((emb.lxu_cache_locations_list[0] == NOT_FOUND).sum().item()) - emb.forward(indices, offsets) + emb.forward(indices.long(), offsets.long()) logging.info( f"Exchanged cache lines -- mean: {sum(exchanged_cache_lines) / len(requests): .2f}, " f"max: {max(exchanged_cache_lines)}, min: {min(exchanged_cache_lines)}" @@ -1174,7 +1135,6 @@ def nbit_cpu( # noqa C901 @click.option("--iters", default=100) @click.option("--runs-of-iters", default=5) @click.option("--warmup-runs", default=2) -@click.option("--warmup-ms", type=int, default=None) @click.option("--output-dtype", type=SparseType, default=SparseType.FP16) @click.option("--report-aibench", is_flag=True) @click.option("--run-reference", is_flag=True, default=False) @@ -1188,17 +1148,6 @@ def nbit_cpu( # noqa C901 type=str, default="{tbe_type}_tbe_{phase}_trace_{ospid}.json", ) -@click.option( - "--warmup-runs", - default=2, - help="Number of warmup runs. Ignored if --warmup-ms is set.", -) -@click.option( - "--warmup-ms", - type=int, - default=None, - help="Warmup duration in milliseconds. 
Disables the --run-nums option.", -) def nbit_device( # noqa C901 alpha: float, bag_size: int, @@ -1219,6 +1168,7 @@ def nbit_device( # noqa C901 check_median: bool, iters: int, runs_of_iters: int, + warmup_runs: int, output_dtype: SparseType, report_aibench: bool, run_reference: bool, @@ -1228,8 +1178,6 @@ def nbit_device( # noqa C901 fp8_exponent_bias: Optional[int], export_trace: bool, trace_url: str, - warmup_runs: int, - warmup_ms: Optional[int], ) -> None: np.random.seed(42) torch.manual_seed(42) @@ -1347,7 +1295,6 @@ def nbit_device( # noqa C901 per_sample_weights, ), check_median=check_median, - warmup_ms=warmup_ms, ) # free up GPU memory @@ -1377,6 +1324,18 @@ def nbit_device( # noqa C901 f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB" ) + # Get trace for one run of iter + tbe_type: str = "split" + + def _kineto_trace_handler(p: profile, phase: str) -> None: + p.export_chrome_trace( + trace_url.format(tbe_type=tbe_type, phase=phase, ospid=os.getpid()) + ) + + # pyre-ignore[3] + def context_factory(on_trace_ready: Callable[[profile], None]): + return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext() + requests = generate_requests( iters, B, @@ -1394,35 +1353,7 @@ def nbit_device( # noqa C901 for req in requests ] - # pyre-ignore[3] - def context_factory(on_trace_ready: Callable[[profile], None]): - return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext() - - # Get trace for one run of iter - tbe_type: str = "split" - # input of the kineto_trace_profiler - trace_info = ("fwd", trace_url, tbe_type, "embedding_codegen_forward") - time_dict = {"kernel_time": None} # dict to hold the kernel time - - # warm-up right before profiling - # warmup_ms prioritized over warmup_runs - if warmup_ms or warmup_runs: - warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - warmup_runs, - lambda indices, offsets, per_sample_weights: emb.forward( - indices.int(), - offsets.int(), - per_sample_weights, - ), - ) - - with context_factory( - # pyre-ignore[6] - lambda p: time_dict.update(kernel_time=kineto_trace_profiler(p, trace_info)) - ): + with context_factory(lambda p: _kineto_trace_handler(p, "fwd")): # forward time_per_iter = benchmark_requests( requests, @@ -1433,21 +1364,6 @@ def context_factory(on_trace_ready: Callable[[profile], None]): ), check_median=check_median, ) - - if export_trace: - kernel_time = time_dict["kernel_time"] - # pyre-ignore[58] - bandwidth = read_write_bytes / kernel_time / 1.0e3 - - logging.info( - f"kineto profiled stats: " - f"{weights_precision} Forward, B: {B}, " - f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " - f"BW: {bandwidth: .2f} GB/s, " # noqa: B950 - f"Time: {kernel_time:.0f}us, " - f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB" - ) - # free up GPU memory del requests @@ -1549,28 +1465,12 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @click.option("--check-median", is_flag=True, default=True) @click.option("--iters", default=100) @click.option("--runs-of-iters", default=5) +@click.option("--warmup-runs", default=2) @click.option("--output-dtype", type=SparseType, default=SparseType.FP16) @click.option("--report-aibench", is_flag=True) @click.option("--fp8-exponent-bits", type=int, default=None) @click.option("--fp8-exponent-bias", type=int, default=None) @click.option("--use-cpu", is_flag=True, default=False) -@click.option("--export-trace", is_flag=True, default=False) -@click.option( - "--trace-url", - type=str, - 
default="{tbe_type}_tbe_spec_{phase}_trace_{ospid}.json", -) -@click.option( - "--warmup-runs", - default=2, - help="Number of warmup runs. Ignored if --warmup-ms is set.", -) -@click.option( - "--warmup-ms", - type=int, - default=None, - help="Warmup duration in milliseconds. Disables the --run-nums option.", -) def nbit_device_with_spec( # noqa C901 alpha: float, bag_size_list: str, @@ -1590,21 +1490,19 @@ def nbit_device_with_spec( # noqa C901 check_median: bool, iters: int, runs_of_iters: int, + warmup_runs: int, output_dtype: SparseType, report_aibench: bool, fp8_exponent_bits: Optional[int], fp8_exponent_bias: Optional[int], use_cpu: bool, - export_trace: bool, - trace_url: str, - warmup_runs: int, - warmup_ms: Optional[int], ) -> None: np.random.seed(42) torch.manual_seed(42) B = batch_size Ds = [int(D) for D in embedding_dim_list.split(",")] Ls = [int(L) for L in bag_size_list.split(",")] + # max_Ls = max(Ls) Es = [int(E) for E in num_embeddings_list.split(",")] E = np.mean(Es) D = np.mean(Ds) @@ -1710,7 +1608,6 @@ def nbit_device_with_spec( # noqa C901 ) times = [] - kineto_request = [] for i in range(runs_of_iters): # Generate a request for each table then combine all_requests = { @@ -1787,13 +1684,8 @@ def nbit_device_with_spec( # noqa C901 per_sample_weights, ), check_median=check_median, - warmup_ms=warmup_ms, ) - # copy the request of last iteration for kineto profile benchmark - if i == runs_of_iters - 1: - kineto_request = requests - # free up memory del requests @@ -1821,63 +1713,6 @@ def nbit_device_with_spec( # noqa C901 f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB" ) - # pyre-ignore[3] - def context_factory(on_trace_ready: Callable[[profile], None]): - return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext() - - if not use_cpu: - # profile with kineto - tbe_type: str = "split" - time_dict = {"kernel_time": None} # Shared variable to hold the kernel time - trace_info = ("fwd", trace_url, tbe_type, "embedding_codegen_forward") - - # warm-up right before profiling - # warmup_ms prioritized over warmup_runs - if warmup_ms or warmup_runs: - warmup( - kineto_request[0], - # pyre-ignore[6] - warmup_ms, - warmup_runs, - lambda indices, offsets, per_sample_weights: emb.forward( - indices.int(), - offsets.int(), - per_sample_weights, - ), - ) - - with context_factory( - # pyre-ignore[6] - lambda p: time_dict.update(kernel_time=kineto_trace_profiler(p, trace_info)) - ): - # forward - time_per_iter = benchmark_requests( - kineto_request, - lambda indices, offsets, per_sample_weights: emb.forward( - indices.int(), - offsets.int(), - per_sample_weights, - ), - check_median=check_median, - ) - - if export_trace: - kernel_time = time_dict["kernel_time"] - # pyre-ignore[6] - bandwidth = read_write_bytes / kernel_time / 1.0e3 - - logging.info( - f"kineto profiled stats: " - f"{weights_precision} Forward, B: {B}, " - f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " - f"BW: {bandwidth: .2f} GB/s, " # noqa: B950 - f"Time: {kernel_time:.0f}us, " - f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB" - ) - - # free up memory - del kineto_request - if report_aibench and haveAIBench: print( emitMetric( @@ -2989,8 +2824,6 @@ def bounds_check_indices( # noqa C901 E, requests_data_file=requests_data_file, tables=tables, - index_dtype=torch.long, - offset_dtype=torch.long, ) B_offsets = B_offsets.to(get_device()).to(torch.int) else: @@ -3007,8 +2840,6 @@ def bounds_check_indices( # noqa C901 E, requests_data_file=requests_data_file, tables=tables, 
- index_dtype=torch.long, - offset_dtype=torch.long, ) warning = torch.tensor([0]).long().to(get_device()) @@ -3062,8 +2893,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]): requests, lambda indices, offsets, _: torch.ops.fbgemm.bounds_check_indices( rows_per_table, - indices, - offsets, + indices.long(), + offsets.long(), BoundsCheckMode(bounds_check_mode), warning, B_offsets=B_offsets, @@ -3429,8 +3260,6 @@ def device_with_spec( # noqa C901 # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - index_dtype=torch.long, - offset_dtype=torch.long, ) for i, req in enumerate(requests): indices, offsets, weights = req.unpack_3() @@ -3489,8 +3318,8 @@ def device_with_spec( # noqa C901 time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, + indices.long(), + offsets.long(), per_sample_weights, feature_requires_grad=feature_requires_grad, ), @@ -3519,8 +3348,8 @@ def device_with_spec( # noqa C901 time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb( - indices, - offsets, + indices.long(), + offsets.long(), per_sample_weights, feature_requires_grad=feature_requires_grad, ), @@ -3601,20 +3430,20 @@ def benchmark_tbe_input_compression( compressed_lengths = [L] * sum(compressed_batch_sizes) compressed_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( torch.tensor(compressed_lengths, device=get_device()) - ).long() + ) compressed_values = torch.randint( low=0, high=E, size=(sum(compressed_lengths),), device=get_device(), - dtype=torch.long, + dtype=torch.int32, ) batch_sizes = [B] * T lengths = [L] * sum(batch_sizes) offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( torch.tensor(lengths, device=get_device()) - ).long() + ) reindex = [] for t in range(cT): @@ -3627,11 +3456,7 @@ def benchmark_tbe_input_compression( reindex.extend(range(cB * cT, (cB * cT) + (B * cT))) reindex = torch.tensor(reindex, device=get_device()) - values = ( - torch.index_select(compressed_values.reshape(-1, L), 0, reindex) - .flatten() - .long() - ) + values = torch.index_select(compressed_values.reshape(-1, L), 0, reindex).flatten() requests = [ ( @@ -3652,12 +3477,12 @@ def benchmark_tbe_input_compression( requests, compressed_requests, baseline_func=lambda indices, offsets: emb.forward( - indices, - offsets, + indices.long(), + offsets.long(), ), compressed_func=lambda indices, offsets: emb.forward( - indices, - offsets, + indices.long(), + offsets.long(), batch_size_per_feature_per_rank=[[bs] for bs in compressed_batch_sizes], ), reindex=reindex, @@ -3773,7 +3598,7 @@ def vbe( lengths = torch.cat(lengths_list, 0) # Convert lengths into offsets. - offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths).long() + offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) # Set up values. 
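Aside on the hunk above: with the trailing `.long()` casts removed, `offsets` keeps the integer dtype produced by `torch.ops.fbgemm.asynchronous_complete_cumsum`, and the benchmark lambdas cast at call time via `indices.long()` / `offsets.long()`. A minimal sketch of the lengths-to-offsets conversion, assuming `fbgemm_gpu` is installed so the op is registered; `B`, `T`, `L` are illustrative values, not the benchmark defaults:

import torch
import fbgemm_gpu  # noqa: F401  -- importing registers torch.ops.fbgemm.* (assumes a full fbgemm_gpu build)

B, T, L = 4, 2, 3  # illustrative batch size, number of tables, pooling factor
lengths = torch.full((B * T,), L, dtype=torch.int64)

# "Complete" cumsum: len(lengths) + 1 entries starting at 0, usable directly as TBE offsets.
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)

# Plain-Python reference for the same computation.
ref = [0]
for n in lengths.tolist():
    ref.append(ref[-1] + n)
assert offsets.tolist() == ref

# The offsets keep their native integer dtype; the forward call casts as needed,
# e.g. emb.forward(indices.long(), offsets.long(), ...).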
values_list: List[torch.Tensor] = [] @@ -3788,7 +3613,7 @@ def vbe( device=get_device(), ) ) - values = torch.cat(values_list, 0).long() + values = torch.cat(values_list, 0) requests = [ ( @@ -3801,8 +3626,8 @@ def vbe( fwd_time_sec, bwd_time_sec = benchmark_vbe( requests, func=lambda indices, offsets: emb.forward( - indices, - offsets, + indices.long(), + offsets.long(), batch_size_per_feature_per_rank=[[B] for B in Bs], ), ) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp index 5fcc3a0176..b03a4f0a4d 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp @@ -200,6 +200,12 @@ Tensor int_nbit_split_embedding_codegen_forward_unweighted_cuda( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -224,6 +230,12 @@ Tensor int_nbit_split_embedding_codegen_forward_weighted_cuda( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -248,6 +260,12 @@ Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t row_alignment, @@ -272,6 +290,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -308,6 +332,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices.to(at::kInt), offsets.to(at::kInt), row_alignment ? *row_alignment : 16, @@ -316,7 +346,8 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( lxu_cache_locations.value_or(at::empty({0}, at::kInt)), max_float8_D ? *max_float8_D : 0, fp8_exponent_bits ? *fp8_exponent_bits : -1, - fp8_exponent_bias ? *fp8_exponent_bias : -1); + fp8_exponent_bias ? *fp8_exponent_bias : -1 + ); } if (!indice_weights || indice_weights->numel() == 0) { return int_nbit_split_embedding_codegen_forward_unweighted_cuda( @@ -332,6 +363,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices, offsets, pooling_mode, @@ -341,7 +378,8 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( lxu_cache_locations.value_or(at::empty({0}, at::kInt)), max_float8_D ? *max_float8_D : 0, fp8_exponent_bits ? *fp8_exponent_bits : -1, - fp8_exponent_bias ? *fp8_exponent_bias : -1); + fp8_exponent_bias ? 
*fp8_exponent_bias : -1 + ); } // Force casting indice_weights to float (doing this in the backend to avoid // JIT issue) @@ -359,6 +397,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices, offsets, pooling_mode, @@ -369,15 +413,15 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( lxu_cache_locations.value_or(at::empty({0}, at::kInt)), max_float8_D ? *max_float8_D : 0, fp8_exponent_bits ? *fp8_exponent_bits : -1, - fp8_exponent_bias ? *fp8_exponent_bias : -1); + fp8_exponent_bias ? *fp8_exponent_bias : -1 + ); } ///@ingroup embedding-cuda -/// Simlar to int_nbit_split_embedding_codegen_lookup_function, but it does /// UVM_CACHING lookup. Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( // First args should be the same to those of - // int_nbit_split_embedding_codegen_lookup_function. + Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, @@ -390,6 +434,12 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -547,6 +597,12 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices, offsets, pooling_mode, @@ -557,7 +613,8 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( row_alignment, max_float8_D, fp8_exponent_bits, - fp8_exponent_bias); + fp8_exponent_bias + ); } ///@ingroup embedding-cuda diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp index 2d2008abea..578e88e96f 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp @@ -91,6 +91,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -180,6 +186,12 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -212,6 +224,12 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu( max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices, offsets, pooling_mode, @@ -222,7 +240,8 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu( row_alignment, max_float8_D, fp8_exponent_bits, - fp8_exponent_bias); + fp8_exponent_bias + ); } ///@ingroup embedding-cpu @@ -254,14 +273,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py"); #endif m.def( - "int_nbit_split_embedding_codegen_lookup_function(Tensor 
dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1) -> Tensor", + "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D ,int INT2_max_ls, int INT4_max_ls, int INT8_max_ls, int FP8_max_ls, int FP16_max_ls, int FP32_max_ls, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1 ) -> Tensor", {PT2_COMPLIANT_TAG}); DISPATCH_TO_CPU( "int_nbit_split_embedding_codegen_lookup_function", int_nbit_split_embedding_codegen_lookup_function_cpu); m.def( - "int_nbit_split_embedding_uvm_caching_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights=None, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment=-1, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, Tensor? cache_hash_size_cumsum=None, int? total_cache_hash_size=-1, Tensor? cache_index_table_map=None, Tensor? lxu_cache_state=None, Tensor? lxu_state=None) -> Tensor"); + "int_nbit_split_embedding_uvm_caching_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D ,int INT2_max_ls, int INT4_max_ls, int INT8_max_ls, int FP8_max_ls, int FP16_max_ls, int FP32_max_ls, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights=None, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment=-1, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, Tensor? cache_hash_size_cumsum=None, int? total_cache_hash_size=-1, Tensor? cache_index_table_map=None, Tensor? lxu_cache_state=None, Tensor? 
lxu_state=None) -> Tensor"); DISPATCH_TO_CPU( "int_nbit_split_embedding_uvm_caching_codegen_lookup_function", int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu); diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index c86d5e6e89..d0ee258086 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -10,6 +10,7 @@ {%- set wdesc = "weighted" if weighted else "unweighted" %} #include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor.h" +#include "fbgemm_gpu/config/feature_gates.h" using namespace fbgemm_gpu; using Tensor = at::Tensor; @@ -23,7 +24,7 @@ namespace nbit { same generated source file. */ {%- for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %} -template +template __launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( const pta::PackedTensorAccessor64 dev_weights, @@ -50,6 +51,9 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no const int fp8_exponent_bits, const int fp8_exponent_bias, {%- endif %} + const int32_t num_packed_bags, + const int32_t num_packed_bags_D, + const int32_t num_packed_bags_L, pta::PackedTensorAccessor32 output, // [B][total_D], const pta::PackedTensorAccessor64 lxu_cache_weights, const pta::PackedTensorAccessor32 lxu_cache_locations @@ -70,9 +74,9 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no #endif // Define {{ emb_weight_type }} kernel invocation macro - #define X(DeviceOnly, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ - {{ func_name }}<<< \ - nbit::div_round_up(T * nbit::div_round_up(B, OutputRowsPerThread), kWarpsPerBlock), \ + #define X(DeviceOnly, PackedMode, PackedMode_L, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + {{ func_name }}<<< \ + nbit::div_round_up(T * nbit::div_round_up(B, num_packed_bags * OutputRowsPerThread), kWarpsPerBlock), \ dim3(kWarpSize, kWarpsPerBlock), \ 0, \ at::cuda::getCurrentCUDAStream()>>>( \ @@ -86,7 +90,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no {%- else %} D, \ {%- endif %} - FixedDivisor(div_round_up(B, OutputRowsPerThread)), \ + FixedDivisor(div_round_up(B, num_packed_bags * OutputRowsPerThread)), \ MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, indices, index_t, 1, 32), \ MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, offsets, index_t, 1, 32), \ {%- if not nobag %} @@ -100,6 +104,9 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no fp8_exponent_bits, \ fp8_exponent_bias, \ {%- endif %} + num_packed_bags, \ + num_packed_bags_D, \ + num_packed_bags_L, \ MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, output, output_t, 2, 32), \ MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, lxu_cache_weights, uint8_t, 2, 64), \ MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, lxu_cache_locations, int32_t, 1, 32) \ @@ -187,6 +194,12 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ const int64_t max_int8_D, const int64_t max_float16_D, const int64_t max_float32_D, + const int64_t INT2_max_ls, + const 
int64_t INT4_max_ls, + const int64_t INT8_max_ls, + const int64_t FP8_max_ls, + const int64_t FP16_max_ls, + const int64_t FP32_max_ls, Tensor indices, Tensor offsets, {%- if not nobag %} @@ -225,11 +238,43 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ constexpr int32_t kWarpsPerBlock = 4; const auto device_only = lxu_cache_weights.numel() == 0 && uvm_weights.numel() == 0; + /* + * Helper macro for run-time packed mode dispatch. Computes maximum number of bags + * (num_packed_bags) that fits into NumUint4LoadsPerRow given embeddings' type and + * size. num_packed_bags is to be used for additional bags indexing + * + * Current support range: ROCm and output_t != uint8_t and sparse_type != FP32 + */ + #define PACKED_MODE_SWITCH(dev_only, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + int32_t num_packed_bags = 1; \ + int32_t num_packed_bags_D = 1; \ + int32_t num_packed_bags_L = 1; \ + const int64_t max_L = max_Ls; \ + {%-if is_rocm and not nobag %} + const static bool use_packed_bag_mode = fbgemm_gpu::config::is_feature_enabled( \ + fbgemm_gpu::config::FeatureGateName::TBE_ROCM_INFERENCE_PACKED_BAGS); \ + if(use_packed_bag_mode) { \ + const int32_t num_uint4_loads_per_row = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), sizeof(uint4)); \ + constexpr int32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); \ + constexpr int32_t max_indices_per_warp = kWarpSize / NumUint4LoadsPerRow; \ + num_packed_bags_L = max_indices_per_warp > max_L && !std::is_same_v && sparse_type != SparseType::FP32? max_indices_per_warp / max_L : 1; \ + num_packed_bags_D = NumUint4LoadsPerRow > num_uint4_loads_per_row && !std::is_same_v && sparse_type != SparseType::FP32 ? NumUint4LoadsPerRow / num_uint4_loads_per_row : 1; \ + num_packed_bags = num_packed_bags_L>1 ? num_packed_bags_D * num_packed_bags_L : num_packed_bags_D; \ + } \ + {%- endif %} + if (num_packed_bags > 1 && num_packed_bags_L>1) { \ + X(dev_only, true, true, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + } else if (num_packed_bags > 1 && num_packed_bags_L<=1) { \ + X(dev_only, true, false, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + } else { \ + X(dev_only, false, false, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + }; + #define Y(...) 
\ if (device_only) { \ - X(true, __VA_ARGS__) \ + PACKED_MODE_SWITCH(true, __VA_ARGS__) \ } else { \ - X(false, __VA_ARGS__) \ + PACKED_MODE_SWITCH(false, __VA_ARGS__) \ }; @@ -241,7 +286,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int2_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int2_D > 0) { - auto max_int2_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_int2_D, SparseType::INT2, row_alignment), 128); + const auto max_D = max_int2_D; + const auto max_Ls = INT2_max_ls; + constexpr auto sparse_type = SparseType::INT2; + auto max_int2_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_int2_128b_rows <= 8); if (max_int2_128b_rows > 0) { Y(2, 16, 0, 1); @@ -268,14 +316,24 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int4_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int4_D > 0) { - auto max_int4_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_int4_D, SparseType::INT4, row_alignment), 128); + const auto max_D = max_int4_D; + const auto max_Ls = INT4_max_ls; + constexpr auto sparse_type = SparseType::INT4; + auto max_int4_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_int4_128b_rows <= 16); + + {%- if is_rocm %} + if (max_int4_128b_rows > 0) { + Y(2, 8, 0, 2); + } + {%- else %} if (max_int4_128b_rows > 0) { Y(4, 8, 0, 1); } if (max_int4_128b_rows > 1) { Y(2, 8, 1, 2); } + {%- endif %} if (max_int4_128b_rows > 2) { Y(1, 4, 2, 4); } @@ -283,7 +341,12 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ Y(1, 4, 4, 8); } if (max_int4_128b_rows > 8) { + {%- if is_rocm %} + Y(1, 2, 8, 16); + {%- else %} Y(1, 4, 8, 16); + {%- endif %} + } } })); @@ -298,14 +361,23 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int8_D > 0) { - auto max_int8_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_int8_D, SparseType::INT8, row_alignment), 128); + const auto max_D = max_int8_D; + const auto max_Ls = INT8_max_ls; + constexpr auto sparse_type = SparseType::INT8; + auto max_int8_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_int8_128b_rows <= 32); + {%- if is_rocm %} + if (max_int8_128b_rows > 0) { + Y(2, 4, 0, 2); + } + {%- else %} if (max_int8_128b_rows > 0) { Y(2, 8, 0, 1); } if (max_int8_128b_rows > 1) { Y(2, 4, 1, 2); } + {%- endif %} if (max_int8_128b_rows > 2) { Y(2, 4, 2, 4); } @@ -313,10 +385,18 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ Y(2, 4, 4, 8); } if (max_int8_128b_rows > 8) { + {%- if is_rocm %} + Y(1, 2, 8, 16); + {%- else %} Y(2, 2, 8, 16); + {%- endif %} } if (max_int8_128b_rows > 16) { + {%- if is_rocm %} + Y(1, 1, 16, 32); + {%- else %} Y(1, 2, 16, 32); + {%- endif %} } } })); @@ -331,7 +411,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float8_D > 0) { - auto 
max_fp8_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_float8_D, SparseType::FP8, row_alignment), 128); + const auto max_D = max_float8_D; + const auto max_Ls = FP8_max_ls; + constexpr auto sparse_type = SparseType::FP8; + auto max_fp8_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_fp8_128b_rows <= 32); if (max_fp8_128b_rows > 0) { Y(2, 8, 0, 1); @@ -364,7 +447,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp16_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float16_D > 0) { - auto max_fp16_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_float16_D, SparseType::FP16, row_alignment), 128); + const auto max_D = max_float16_D; + const auto max_Ls = FP16_max_ls; + constexpr auto sparse_type = SparseType::FP16; + auto max_fp16_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_fp16_128b_rows <= 64); if (max_fp16_128b_rows > 0) { Y(2, 8, 0, 2); @@ -397,7 +483,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp32_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float32_D > 0) { - auto max_fp32_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_float32_D, SparseType::FP32, row_alignment), 128); + const auto max_D = max_float32_D; + const auto max_Ls = FP32_max_ls; + constexpr auto sparse_type = SparseType::FP32; + auto max_fp32_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_fp32_128b_rows <= 64); // 128 doesn't fit in 48KB SM, so FP32 TBE supports a smaller dimension than others if (max_fp32_128b_rows > 0) { Y(2, 4, 0, 4); @@ -435,6 +524,12 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ const int64_t max_int8_D, const int64_t max_float16_D, const int64_t max_float32_D, + const int64_t INT2_max_ls, + const int64_t INT4_max_ls, + const int64_t INT8_max_ls, + const int64_t FP8_max_ls, + const int64_t FP16_max_ls, + const int64_t FP32_max_ls, Tensor indices, Tensor offsets, {%- if not nobag %} @@ -496,6 +591,12 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices, offsets, {%- if not nobag %} diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index 6be31ab475..04408087fb 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -17,7 +17,7 @@ using Tensor = at::Tensor; namespace nbit { // TODO: increase code sharing (templates for accumulator_ty, accumulation, outputs per thread, etc?) 
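For reference, a small Python restatement of the packed-bag capacity arithmetic in the PACKED_MODE_SWITCH macro introduced earlier in this patch, assuming kWarpSize = 64 (ROCm) and sizeof(uint4) = 16; the example inputs are hypothetical, and the real dispatch additionally gates on the TBE_ROCM_INFERENCE_PACKED_BAGS feature flag, output_t != uint8_t, and sparse_type != FP32:

def div_round_up(a: int, b: int) -> int:
    return (a + b - 1) // b

def packed_bag_counts(padded_row_bytes: int, max_L: int,
                      max_num_128b_rows: int, warp_size: int = 64):
    # One uint4 vector load is 16 bytes; count how many cover a padded embedding row.
    uint4_bytes = 16
    uint4_loads_per_row = div_round_up(padded_row_bytes, uint4_bytes)
    # Loads available in one warp row slot (NumUint4LoadsPerRow in the kernel).
    uint4_loads_per_slot = max_num_128b_rows * 128 // uint4_bytes
    max_indices_per_warp = warp_size // uint4_loads_per_slot

    # Pack along D: several narrow rows share one MaxNum128BRows-wide slot.
    num_packed_bags_d = (uint4_loads_per_slot // uint4_loads_per_row
                         if uint4_loads_per_slot > uint4_loads_per_row else 1)
    # Pack along L: several short bags share one warp's index slots.
    num_packed_bags_l = (max_indices_per_warp // max_L
                         if max_indices_per_warp > max_L else 1)
    num_packed_bags = (num_packed_bags_d * num_packed_bags_l
                       if num_packed_bags_l > 1 else num_packed_bags_d)
    return num_packed_bags_d, num_packed_bags_l, num_packed_bags

# Hypothetical example: 64-byte padded rows, bags of length 1, MaxNum128BRows = 2
# -> 4 bags packed along D, 4 along L, 16 bags sharing one warp in total.
print(packed_bag_counts(padded_row_bytes=64, max_L=1, max_num_128b_rows=2))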
-template +template __launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( const pta::PackedTensorAccessor64 dev_weights, @@ -44,6 +44,9 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no const int exponent_bits, const int exponent_bias, {% endif %} + const int32_t num_packed_bags, + const int32_t num_packed_bags_D, + const int32_t num_packed_bags_L, pta::PackedTensorAccessor32 output, // [B][total_D], const pta::PackedTensorAccessor64 lxu_cache_weights, const pta::PackedTensorAccessor32 lxu_cache_locations @@ -77,6 +80,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no if (weight_ty != SparseType::{{ emb_weight_type.enum_name }}) { return; } + bool check_packed_mode = PackedMode; + bool check_packed_mode_L = PackedMode_L; // default to 16 byte alignment for GPU TBE const int32_t D_bytes = padded_row_size_in_bytes(D, weight_ty, row_alignment); @@ -84,247 +89,603 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no if (D_bytes <= MinNum128BRows * 128 || D_bytes > MaxNum128BRows * 128) { return; } + if (PackedMode_L){ + const int64_t weights_offset = weights_offsets[t]; + const int32_t D_total = padded_D(D, weight_ty); + const int32_t D_padding = D_total - D; + + uint32_t warp_idx = threadIdx.y; + int32_t indices_starts[OutputRowsPerThread]; + int32_t Ls[OutputRowsPerThread]; + int32_t max_Ls = 0; + constexpr size_t kOutputsPerThread = {{ (32 // emb_weight_type.bit_width) }}; + const int32_t tot_num_packed_bags = (num_packed_bags_D * num_packed_bags_L); + constexpr uint32_t NumUint4LoadsPerRow = (MaxNum128BRows * 128) / sizeof(uint4); + constexpr int32_t AccumulateStoreRequests = (kWarpSize == 64) ? (MaxNum128BRows + 1) / 2 : AccumulateStoreRequests; + const int32_t bag_size_offset = num_packed_bags_L > 1 ? kWarpSize/(num_packed_bags_L * NumUint4LoadsPerRow) : 1; + const uint32_t uint4_loads_per_row = div_round_up(D_bytes, sizeof(uint4)); + // const int32_t bag_d = kWarpSize/num_packed_bags_L; // num_packed_bags_L can be {1, 2, 4, 8} + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + const uint32_t packed_bag_idx_D = num_packed_bags_D > 1 ? (threadIdx.x % NumUint4LoadsPerRow) / uint4_loads_per_row : 0; + const uint32_t packed_bag_idx_L = num_packed_bags_L > 1 ? (threadIdx.x / NumUint4LoadsPerRow) / bag_size_offset : 0; // 63/8/2 = 7/2 = 3 + const uint32_t packed_bag_idx = (packed_bag_idx_L * num_packed_bags_D) + packed_bag_idx_D; + uint32_t b = min(static_cast(bb * tot_num_packed_bags * OutputRowsPerThread + i * tot_num_packed_bags + packed_bag_idx), static_cast(B - 1)); + int32_t indices_start = offsets[t * B + b]; + int32_t indices_end = offsets[t * B + b + 1]; + indices_starts[i] = indices_start; + Ls[i] = indices_end - indices_start; + max_Ls = max(max_Ls, Ls[i]); + } + const index_t* indices_ = &indices[0]; + const uint8_t* __restrict__ weights; + const auto placement = DeviceOnly ? 
PlacementType::DEVICE : static_cast(weights_placements[t]); + if (placement == PlacementType::DEVICE) { + weights = &dev_weights[weights_offset]; + } else { + weights = &uvm_weights[weights_offset]; + } - const int64_t weights_offset = weights_offsets[t]; - const int32_t D_total = padded_D(D, weight_ty); - const int32_t D_padding = D_total - D; - - uint32_t warp_idx = threadIdx.y; - int32_t indices_starts[OutputRowsPerThread]; - int32_t Ls[OutputRowsPerThread]; - int32_t max_Ls = 0; - - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - uint32_t b = min(static_cast(bb * OutputRowsPerThread + i), static_cast(B - 1)); - int32_t indices_start = offsets[t * B + b]; - int32_t indices_end = offsets[t * B + b + 1]; - indices_starts[i] = indices_start; - Ls[i] = indices_end - indices_start; - max_Ls = max(max_Ls, Ls[i]); - } - const index_t* indices_ = &indices[0]; - - const uint8_t* __restrict__ weights; - const auto placement = DeviceOnly ? PlacementType::DEVICE : static_cast(weights_placements[t]); - if (placement == PlacementType::DEVICE) { - weights = &dev_weights[weights_offset]; - } else { - weights = &uvm_weights[weights_offset]; - } - constexpr size_t kOutputsPerThread = {{ (32 // emb_weight_type.bit_width) }}; - - constexpr uint32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); - const uint32_t uint4_loads_per_row = div_round_up(D_bytes, sizeof(uint4)); - - {% if not nobag %} - VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> accumulators[OutputRowsPerThread][MaxNum128BRows]; - {% endif %} - - for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { - uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); - + {% if not nobag %} + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> accumulators[OutputRowsPerThread][AccumulateStoreRequests]; + + {% endif %} typedef uint4 AllBuffers[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4LoadsPerRow]; __shared__ AllBuffers buffers; + for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { + uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); + + + + {% if weighted %} + typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight]; + __shared__ AllIndiceWeights buffers_indice_weights; + {% endif %} + uint32_t input_row_idx = 0; + for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * num_packed_bags_L * NumUint4LoadsPerRow; load_idx += kWarpSize) { + uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow % uint4_loads_per_row; + // To do: write the condition of bag_size_offset depending on num_packed_bags_L + + input_row_idx = num_packed_bags_L>1? (load_idx / NumUint4LoadsPerRow) % bag_size_offset: (load_idx / NumUint4LoadsPerRow); // 63/32 = 1% 1 = 0 + + const uint32_t packed_bag_idx_D = (threadIdx.x % NumUint4LoadsPerRow) / uint4_loads_per_row; + const uint32_t packed_bag_idx_L = (threadIdx.x / NumUint4LoadsPerRow) / bag_size_offset; + bool load_idx_valid = packed_bag_idx_D < num_packed_bags_D && packed_bag_idx_L < num_packed_bags_L; + {%- if is_rocm %} + constexpr uint32_t kMaxRowUnroll = 4; + constexpr uint32_t kRowUnroll = OutputRowsPerThread < kMaxRowUnroll ? 
OutputRowsPerThread : kMaxRowUnroll; - {% if weighted %} - typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight]; - __shared__ AllIndiceWeights buffers_indice_weights; - {% endif %} - - for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * NumUint4LoadsPerRow; load_idx += kWarpSize) { - uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow; - uint32_t input_row_idx = (load_idx / NumUint4LoadsPerRow); - bool load_idx_valid = row_load_idx < uint4_loads_per_row; - - {%- if is_rocm %} - constexpr uint32_t kMaxRowUnroll = 4; - constexpr uint32_t kRowUnroll = OutputRowsPerThread < kMaxRowUnroll ? OutputRowsPerThread : kMaxRowUnroll; - - #pragma unroll - for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) { - uint4 row_data_v[kRowUnroll]; - const uint4* row_v[kRowUnroll]; - int32_t idx_v[kRowUnroll]; - int32_t cache_idx_v[kRowUnroll]; #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { - uint32_t i = outer_i + inner_i; - bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; - bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); - idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; - cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; - } + for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) { + uint4 row_data_v[kRowUnroll]; + const uint4* row_v[kRowUnroll]; + int32_t idx_v[kRowUnroll]; + int32_t cache_idx_v[kRowUnroll]; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; + cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + } + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + valid = valid && (idx_v[inner_i] != -1); + if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) { + row_v[inner_i] = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx_v[inner_i])][0]); + } else + if (valid) { + row_v[inner_i] = reinterpret_cast(&weights[static_cast(idx_v[inner_i]) * D_bytes]); + } else { + row_v[inner_i] = reinterpret_cast(&weights[0]); + } + } + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + row_data_v[inner_i] = row_v[inner_i][row_load_idx]; + } + uint4 zeros = {0, 0, 0, 0}; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1); + uint4 data = valid ? row_data_v[inner_i] : zeros; + buffers[warp_idx][i][input_row_idx + bag_size_offset *packed_bag_idx_L][row_load_idx + uint4_loads_per_row * packed_bag_idx_D] = data; + {% if weighted %} + buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? 
indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + {% endif %} + } + } + {%- endif %} + + {%- if is_rocm %} + if constexpr (OutputRowsPerThread % kRowUnroll) + { #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { - uint32_t i = outer_i + inner_i; + for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) { + {%- else %} + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + {%- endif %} bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); - valid = valid && (idx_v[inner_i] != -1); - if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) { - row_v[inner_i] = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx_v[inner_i])][0]); - } else - if (valid) { - row_v[inner_i] = reinterpret_cast(&weights[static_cast(idx_v[inner_i]) * D_bytes]); + int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; + int32_t cache_idx = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + valid = valid && (idx != -1); + const uint4* row; + if (!DeviceOnly && cache_valid && cache_idx != kCacheLocationMissing) { + row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); + } else if (valid) { + row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); } else { - row_v[inner_i] = reinterpret_cast(&weights[0]); + row = reinterpret_cast(&weights[0]); } - } - #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { - uint32_t i = outer_i + inner_i; - row_data_v[inner_i] = row_v[inner_i][row_load_idx]; - } - uint4 zeros = {0, 0, 0, 0}; - #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { - uint32_t i = outer_i + inner_i; - bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1); - uint4 data = valid ? row_data_v[inner_i] : zeros; - buffers[warp_idx][i][input_row_idx][row_load_idx] = data; + cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx + bag_size_offset * packed_bag_idx_L][row_load_idx + uint4_loads_per_row * packed_bag_idx_D] , &row[row_load_idx], valid); {% if weighted %} buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; {% endif %} } + {%- if is_rocm %} + } // constexpr if (OutputRowsPerThread % kRowUnroll) + {%- endif %} } - {%- endif %} - - {%- if is_rocm %} - if constexpr (OutputRowsPerThread % kRowUnroll) - { - #pragma unroll - for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) { - {%- else %} - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - {%- endif %} - bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; - bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); - int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; - int32_t cache_idx = (!DeviceOnly && cache_valid) ? 
lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; - valid = valid && (idx != -1); - const uint4* row; - if (!DeviceOnly && cache_valid && cache_idx != kCacheLocationMissing) { - row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); - } else if (valid) { - row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); - } else { - row = reinterpret_cast(&weights[0]); + // equivalent to fence + wait. + cp_async_wait<0>(); + syncwarp(); + const int32_t uints_per_row = 4 * uint4_loads_per_row; + const int32_t packed_bag_idx_D = (threadIdx.x / uints_per_row) % num_packed_bags_D; + // const uint32_t packed_bag_idx_D = num_packed_bags_D > 1 ? (threadIdx.x % NumUint4LoadsPerRow) / uint4_loads_per_row : 0; + input_rows_in_flight = shfl_sync(input_rows_in_flight, packed_bag_idx_D * uint4_loads_per_row); + constexpr int32_t max_indices_per_warp = kWarpSize / (MaxNum128BRows * 128 / sizeof(uint4)); + int32_t Ls_shfl[OutputRowsPerThread*max_indices_per_warp]; + for(uint32_t k = 0; k < num_packed_bags_L ; ++k){ + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + Ls_shfl[k*OutputRowsPerThread+i] = shfl_sync(Ls[i], k * bag_size_offset * NumUint4LoadsPerRow + packed_bag_idx_D * uint4_loads_per_row); } - cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx], &row[row_load_idx], valid); + } + for(uint32_t k = 0; k < num_packed_bags_L ; ++k){ + for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { + + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + bool valid = L_start + input_row_idx < Ls_shfl[k*OutputRowsPerThread+i]; + if (!valid) { + continue; + } + const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx + bag_size_offset *k][0]); + + // const int32_t packed_bag_idx_D = (threadIdx.x / uints_per_row) % num_packed_bags_D; + // scale and bias are at the beginning of each row. + // rationale: have scale/shift at start since these get loaded first + // and then broadcasted around so it might speed up the first cache miss. + {% if emb_weight_type.primitive_type == "INT" %} + half2 shift_scale = reinterpret_cast(row)[(packed_bag_idx_D * uints_per_row)]; + {% endif %} + + {% if weighted %} + float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx]; + {% endif %} + + using scalar_t = {{ emb_weight_type.cpp_type_name }}; + + {% if not nobag %} + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + {% if weighted %} + accumulators[i][j].fma(v, {% if emb_weight_type.primitive_type == "INT" %} shift_scale, {% elif emb_weight_type.enum_name == "FP8" %} exponent_bits, exponent_bias, {% endif %} row_weight); + {% else %} + accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + + {% endif %} + } + + {% else %} + const int32_t output_j = indices_starts[i] + L_start + input_row_idx; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: + // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to + // the scale/shift handling). + // Reason: to avoid divergence the first thread in the warp computes garbage. 
+ const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], num_valid_outputs); + } + } + } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 + auto thread_local_min = std::numeric_limits::max(); + auto thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + if (output_d >= 0 && output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); + } + } + qparams = warp_find_qparams(thread_local_min, thread_local_max); + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[output_j][D], qparams); + } + } + {% endif %} + } + } + + // } + + {% if not nobag %} + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + const int32_t num_stores_with_padding_per_row = 4 * uint4_loads_per_row; + const int32_t packed_bag_idx_D = threadIdx.x / num_stores_with_padding_per_row; + // for(uint32_t k = 0; k < num_packed_bags_L ; ++k){ + uint32_t b = min(static_cast(bb * tot_num_packed_bags * OutputRowsPerThread + i * tot_num_packed_bags + k*num_packed_bags_D + packed_bag_idx_D), static_cast(B - 1)); + const float inv_L = (mean_pooling &&Ls_shfl[k*OutputRowsPerThread+i] != 0) ? 
static_cast(1.0) / Ls_shfl[k*OutputRowsPerThread+i] : static_cast(1.0); + + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding \ + - packed_bag_idx_D * kOutputsPerThread * num_stores_with_padding_per_row; + accumulators[i][j].mul(inv_L); + + if (output_d >= 0 && output_d < D && packed_bag_idx_D < num_packed_bags_D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs); + } + + } + + } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 + float thread_local_min = std::numeric_limits::max(); + float thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + accumulators[i][j].mul(inv_L); + if (output_d >= 0 && output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(accumulators[i][j].acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(accumulators[i][j].acc)); + } + } - {% if weighted %} - buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + qparams = warp_find_qparams(thread_local_min, thread_local_max); + const int output_D_start = D_start + t * 8; + const int output_D_end = output_D_start + D; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + accumulators[i][j].store(&output[b][output_D_start + output_d], qparams, num_valid_outputs); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[b][output_D_end], qparams); + } + } else { + // INT4: not implemented yet + } + } + + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + accumulators[i][j].mul(0.0); // Use a dedicated clear method + } + } + + {% endif %} + } } - {%- if is_rocm %} - } // constexpr if (OutputRowsPerThread % kRowUnroll) - {%- endif %} + } + else { + constexpr int32_t AccumulateStoreRequests = (kWarpSize == 64) ? (MaxNum128BRows + 1) / 2 : MaxNum128BRows; + const int64_t weights_offset = weights_offsets[t]; + const int32_t D_total = padded_D(D, weight_ty); + const int32_t D_padding = D_total - D; + + uint32_t warp_idx = threadIdx.y; + int32_t indices_starts[OutputRowsPerThread]; + int32_t Ls[OutputRowsPerThread]; + int32_t max_Ls = 0; + constexpr size_t kOutputsPerThread = {{ (32 // emb_weight_type.bit_width) }}; + + constexpr uint32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); + const uint32_t uint4_loads_per_row = div_round_up(D_bytes, sizeof(uint4)); + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + const uint32_t packed_bag_idx = PackedMode ? (threadIdx.x % NumUint4LoadsPerRow) / uint4_loads_per_row : 0; + uint32_t b = PackedMode ? 
min(static_cast(bb * num_packed_bags * OutputRowsPerThread + i * num_packed_bags + packed_bag_idx), static_cast(B - 1)) + : min(static_cast(bb * OutputRowsPerThread + i), static_cast(B - 1)); + int32_t indices_start = offsets[t * B + b]; + int32_t indices_end = offsets[t * B + b + 1]; + indices_starts[i] = indices_start; + Ls[i] = indices_end - indices_start; + max_Ls = max(max_Ls, Ls[i]); + } + const index_t* indices_ = &indices[0]; + + const uint8_t* __restrict__ weights; + const auto placement = DeviceOnly ? PlacementType::DEVICE : static_cast(weights_placements[t]); + if (placement == PlacementType::DEVICE) { + weights = &dev_weights[weights_offset]; + } else { + weights = &uvm_weights[weights_offset]; } - // equivalent to fence + wait. - cp_async_wait<0>(); - syncwarp(); - for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - bool valid = L_start + input_row_idx < Ls[i]; - if (!valid) { - continue; + + {% if not nobag %} + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> accumulators[OutputRowsPerThread][AccumulateStoreRequests]; + {% endif %} + + for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { + uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); + + typedef uint4 AllBuffers[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4LoadsPerRow]; + __shared__ AllBuffers buffers; + + {% if weighted %} + typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][PackedMode ? NumUint4LoadsPerRow : 1]; + __shared__ AllIndiceWeights buffers_indice_weights; + {% endif %} + + for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * NumUint4LoadsPerRow; load_idx += kWarpSize) { + uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow; + if constexpr (PackedMode) { + row_load_idx %= uint4_loads_per_row; } - const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); - // scale and bias are at the beginning of each row. - // rationale: have scale/shift at start since these get loaded first - // and then broadcasted around so it might speed up the first cache miss. - {% if emb_weight_type.primitive_type == "INT" %} - half2 shift_scale = reinterpret_cast(row)[0]; - {% endif %} + uint32_t input_row_idx = (load_idx / NumUint4LoadsPerRow); + const uint32_t packed_bag_idx = PackedMode ? (load_idx % NumUint4LoadsPerRow) / uint4_loads_per_row : 0; + bool load_idx_valid = PackedMode ? packed_bag_idx < num_packed_bags : row_load_idx < uint4_loads_per_row; + {%- if is_rocm %} + constexpr uint32_t kMaxRowUnroll = 4; + constexpr uint32_t kRowUnroll = OutputRowsPerThread < kMaxRowUnroll ? OutputRowsPerThread : kMaxRowUnroll; - {% if weighted %} - float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx]; - {% endif %} + #pragma unroll + for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) { + uint4 row_data_v[kRowUnroll]; + const uint4* row_v[kRowUnroll]; + int32_t idx_v[kRowUnroll]; + int32_t cache_idx_v[kRowUnroll]; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + idx_v[inner_i] = valid ? 
indices_[indices_starts[i] + L_start + input_row_idx] : -1; + cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + } - using scalar_t = {{ emb_weight_type.cpp_type_name }}; - {% if not nobag %} - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + valid = valid && (idx_v[inner_i] != -1); + if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) { + row_v[inner_i] = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx_v[inner_i])][0]); + } else + if (valid) { + row_v[inner_i] = reinterpret_cast(&weights[static_cast(idx_v[inner_i]) * D_bytes]); + } else { + row_v[inner_i] = reinterpret_cast(&weights[0]); + } + } + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + row_data_v[inner_i] = row_v[inner_i][row_load_idx]; + } + uint4 zeros = {0, 0, 0, 0}; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1); + uint4 data = valid ? row_data_v[inner_i] : zeros; + if constexpr (PackedMode) { + buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_idx] = data; + } else { + buffers[warp_idx][i][input_row_idx][row_load_idx] = data; + } + {% if weighted %} + buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + {% endif %} + } + } + {%- endif %} + + {%- if is_rocm %} + if constexpr (OutputRowsPerThread % kRowUnroll) + { + #pragma unroll + for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) { + {%- else %} + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + {%- endif %} + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; + int32_t cache_idx = (!DeviceOnly && cache_valid) ? 
lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + valid = valid && (idx != -1); + const uint4* row; + if (!DeviceOnly && cache_valid && cache_idx != kCacheLocationMissing) { + row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); + } else if (valid) { + row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); + } else { + row = reinterpret_cast(&weights[0]); + } + if constexpr (PackedMode) { + cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_idx], &row[row_load_idx], valid); + } else { + cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx], &row[row_load_idx], valid); + } {% if weighted %} - accumulators[i][j].fma(v, {% if emb_weight_type.primitive_type == "INT" %} shift_scale, {% elif emb_weight_type.enum_name == "FP8" %} exponent_bits, exponent_bias, {% endif %} row_weight); - {% else %} - accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; {% endif %} } - {% else %} - const int32_t output_j = indices_starts[i] + L_start + input_row_idx; - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: - // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to - // the scale/shift handling). - // Reason: to avoid divergence the first thread in the warp computes garbage. - const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - if (output_d >= 0 && output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); - VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - acc.store(&output[output_j][output_d], num_valid_outputs); - } + {%- if is_rocm %} + } // constexpr if (OutputRowsPerThread % kRowUnroll) + {%- endif %} + } + // equivalent to fence + wait. + cp_async_wait<0>(); + syncwarp(); + const int32_t uints_per_row = 4 * uint4_loads_per_row; + const int32_t packed_bag_idx = PackedMode ? 
(threadIdx.x / uints_per_row) % num_packed_bags : 0; + if constexpr (PackedMode) { + input_rows_in_flight = shfl_sync(input_rows_in_flight, packed_bag_idx * uint4_loads_per_row); + + #pragma unroll OutputRowsPerThread + for(uint32_t i = 0; i < OutputRowsPerThread; ++i) + { + Ls[i] = shfl_sync(Ls[i], packed_bag_idx * uint4_loads_per_row); + } + } + + for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + bool valid = L_start + input_row_idx < Ls[i]; + if (!valid) { + continue; } - } else if constexpr (std::is_same_v) { - // INT8: - // apply per feature row-wise int8 - auto thread_local_min = std::numeric_limits::max(); - auto thread_local_max = std::numeric_limits::lowest(); - float2 qparams; - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); + // scale and bias are at the beginning of each row. + // rationale: have scale/shift at start since these get loaded first + // and then broadcasted around so it might speed up the first cache miss. + {% if emb_weight_type.primitive_type == "INT" %} + half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_idx * uints_per_row : 0]; + {% endif %} + + {% if weighted %} + float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_idx : 0]; + {% endif %} + + using scalar_t = {{ emb_weight_type.cpp_type_name }}; + + {% if not nobag %} + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - if (output_d >= 0 && output_d < D) { - thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); - thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); - } + {% if weighted %} + accumulators[i][j].fma(v, {% if emb_weight_type.primitive_type == "INT" %} shift_scale, {% elif emb_weight_type.enum_name == "FP8" %} exponent_bits, exponent_bias, {% endif %} row_weight); + {% else %} + accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + {% endif %} } - qparams = warp_find_qparams(thread_local_min, thread_local_max); - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - if (output_d >= 0 && output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + {% else %} + const int32_t output_j = indices_starts[i] + L_start + input_row_idx; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + // Read the uint8/4/2 values: note that first 
4 Bytes will be ditched later: + // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to + // the scale/shift handling). + // Reason: to avoid divergence the first thread in the warp computes garbage. + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], num_valid_outputs); + } + } + } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 + auto thread_local_min = std::numeric_limits::max(); + auto thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + if (output_d >= 0 && output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); + } + } + qparams = warp_find_qparams(thread_local_min, thread_local_max); + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[output_j][D], qparams); } } - if (threadIdx.x == 0) { - store_qparams_to_row(&output[output_j][D], qparams); - } + {% endif %} } - {% endif %} } } - } {% if not nobag %} #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - const uint32_t b = min(static_cast(bb * OutputRowsPerThread + i), static_cast(B - 1)); - const float inv_L = (mean_pooling && Ls[i] != 0) ? static_cast(1.0) / Ls[i]: static_cast(1.0); + const int32_t num_stores_with_padding_per_row = 4 * uint4_loads_per_row; + const int32_t packed_bag_idx = PackedMode ? 
threadIdx.x / num_stores_with_padding_per_row : 0; + const uint32_t b = min(static_cast(bb * num_packed_bags * OutputRowsPerThread + i * num_packed_bags + packed_bag_idx), static_cast(B - 1)); + const float inv_L = (mean_pooling && Ls[i] != 0) ? static_cast(1.0) / Ls[i] : static_cast(1.0); if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + if constexpr (PackedMode) { + output_d -= packed_bag_idx * kOutputsPerThread * num_stores_with_padding_per_row; + } accumulators[i][j].mul(inv_L); - if (output_d >= 0 && output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); - accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs); + if constexpr (PackedMode) { + if (output_d >= 0 && output_d < D && packed_bag_idx < num_packed_bags) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs); + } + } else { + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs); + } } } @@ -334,8 +695,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no float thread_local_min = std::numeric_limits::max(); float thread_local_max = std::numeric_limits::lowest(); float2 qparams; - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; accumulators[i][j].mul(inv_L); if (output_d >= 0 && output_d < D) { @@ -347,8 +708,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no qparams = warp_find_qparams(thread_local_min, thread_local_max); const int output_D_start = D_start + t * 8; const int output_D_end = output_D_start + D; - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; if (output_d >= 0 && output_d < D) { const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); @@ -364,10 +725,13 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no } {% endif %} } +} // kWarpsPerBlock is defined in embedding_forward_quantized_split_nbit_host_template.cu {% set warps_per_block = '4' %} +{% for packed_mode in ['true', 'false'] %} +{% for packed_mode_L in ['true', 'false'] %} {% for device_only in ['true', 'false'] %} {% for output_type in ['at::Half', 'at::BFloat16', 'float', 'uint8_t'] %} {% for index_type in ['int32_t', 'int64_t'] %} @@ -388,7 +752,9 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if 
nobag else "" {{ params.input_rows_in_flight }}, {{ params.min_128b_rows }}, {{ params.max_128b_rows }}, - {{ device_only }} > ( + {{ device_only }}, + {{ packed_mode }}, + {{ packed_mode_L }} > ( const pta::PackedTensorAccessor64 dev_weights, const pta::PackedTensorAccessor64 uvm_weights, const pta::PackedTensorAccessor32 weights_placements, @@ -413,6 +779,9 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else "" const int exponent_bits, const int exponent_bias, {% endif %} + const int32_t num_packed_bags, + const int32_t num_packed_bags_D, + const int32_t num_packed_bags_L, pta::PackedTensorAccessor32<{{ output_type }}, 2, at::RestrictPtrTraits> output, // [B][total_D], const pta::PackedTensorAccessor64 lxu_cache_weights, const pta::PackedTensorAccessor32 lxu_cache_locations @@ -426,6 +795,8 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else "" {% endfor %} // for index_type in ['int32_t', 'int64_t'] {% endfor %} // for output_type in [True, False] {% endfor %} // device_only in [True, False] +{% endfor %} // packed_bags in ['true', 'false'] +{% endfor %} // packed_bags_L in ['true', 'false'] } diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 23e3397d76..1002d55ab7 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -55,7 +55,12 @@ import fbgemm_gpu # noqa - +def find_max_ls(ty: SparseType, weights_tys:List[SparseType], offsets: Tensor )-> int: + bag_sizes = offsets[1:] - offsets[:-1] + for type_ in weights_tys: + if type_ == ty or type_.value == ty.value: + return bag_sizes.max().item() + return 0 def rounded_row_size_in_bytes( dim: int, weight_ty: SparseType, @@ -469,6 +474,8 @@ def max_ty_D(ty: SparseType) -> int: ], default=0, ) + + self.max_int2_D: int = max_ty_D(SparseType.INT2) self.max_int4_D: int = max_ty_D(SparseType.INT4) @@ -476,7 +483,6 @@ def max_ty_D(ty: SparseType) -> int: self.max_float8_D: int = max_ty_D(SparseType.FP8) self.max_float16_D: int = max_ty_D(SparseType.FP16) self.max_float32_D: int = max_ty_D(SparseType.FP32) - self.register_buffer( "D_offsets", torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32), @@ -932,6 +938,19 @@ def _forward_impl( indices, offsets, per_sample_weights = inputs_to_device( indices, offsets, per_sample_weights, self.bounds_check_warning.device ) + weights_tys: List[SparseType] = [e[3] for e in self.embedding_specs] + + INT2_max_ls = find_max_ls(SparseType.INT2, weights_tys, offsets) + INT4_max_ls = find_max_ls(SparseType.INT4, weights_tys, offsets) + INT8_max_ls = find_max_ls(SparseType.INT8, weights_tys, offsets) + FP8_max_ls = find_max_ls(SparseType.FP8, weights_tys, offsets) + FP16_max_ls = find_max_ls(SparseType.FP16, weights_tys, offsets) + FP32_max_ls = find_max_ls(SparseType.FP32, weights_tys, offsets) + + + + + # First bound check: check if the indices/offsets are within the boundary # of the original embedding rows before pruning. 
@@ -1009,6 +1028,12 @@ def _forward_impl(
             max_int8_D=self.max_int8_D,
             max_float16_D=self.max_float16_D,
             max_float32_D=self.max_float32_D,
+            INT2_max_ls=INT2_max_ls,
+            INT4_max_ls=INT4_max_ls,
+            INT8_max_ls=INT8_max_ls,
+            FP8_max_ls=FP8_max_ls,
+            FP16_max_ls=FP16_max_ls,
+            FP32_max_ls=FP32_max_ls,
             indices=indices,
             offsets=offsets,
             pooling_mode=int(self.pooling_mode),
@@ -1019,7 +1044,7 @@ def _forward_impl(
             row_alignment=self.row_alignment,
             max_float8_D=self.max_float8_D,
             fp8_exponent_bits=self.fp8_exponent_bits,
-            fp8_exponent_bias=self.fp8_exponent_bias,
+            fp8_exponent_bias=self.fp8_exponent_bias
         )
 
     def forward(
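For context on the new num_packed_bags / num_packed_bags_D / num_packed_bags_L kernel arguments: in packed mode the kernel lays several narrow bags side by side inside one warp's window of 128-byte loads, indexing them by how many uint4 loads a single embedding row needs (uint4_loads_per_row) versus how many the window provides (NumUint4LoadsPerRow). The host-side derivation of these counts is not shown in this patch, so the helper below is only a rough Python sketch of the arithmetic the kernel's indexing implies; the function name and the example numbers are illustrative.

def packed_bags_per_row(d_bytes: int, max_128b_rows: int) -> int:
    # One uint4 load moves 16 bytes; a warp's load window spans max_128b_rows * 128 bytes.
    uint4_bytes = 16
    num_uint4_loads_per_row = max_128b_rows * 128 // uint4_bytes
    # Ceil-divide: uint4 loads needed for a single embedding row of d_bytes.
    uint4_loads_per_row = -(-d_bytes // uint4_bytes)
    # How many such rows (bags) fit side by side in the window; at least one.
    return max(1, num_uint4_loads_per_row // uint4_loads_per_row)

# Hypothetical: 40-byte rows against a 2x128B budget -> 16 uint4 slots // 3 per bag = 5 bags.
print(packed_bags_per_row(d_bytes=40, max_128b_rows=2))  # 5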