@@ -10,7 +10,10 @@
 from typing import Type
 
 import torch
-from neuronx_distributed.quantization.quantization_config import QuantizationType, ActivationQuantizationType
+from neuronx_distributed.quantization.quantization_config import (
+    ActivationQuantizationType,
+    QuantizationType,
+)
 from transformers import AutoTokenizer, GenerationConfig
 
 from neuronx_distributed_inference.models.application_base import NeuronApplicationBase
@@ -28,9 +31,11 @@
     check_accuracy_logits,
     get_generate_outputs,
 )
+from neuronx_distributed_inference.utils import argparse_utils
 from neuronx_distributed_inference.utils.benchmark import benchmark_sampling
 from neuronx_distributed_inference.utils.debug_utils import capture_model_inputs
 from neuronx_distributed_inference.utils.distributed import get_init_rank, get_init_world_size
+from neuronx_distributed_inference.utils.exceptions import LogitMatchingValidationError
 from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config
 from neuronx_distributed_inference.utils.random import set_random_seed
 
@@ -117,10 +122,12 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument("--rpl-reduce-dtype", type=to_torch_dtype)
     run_parser.add_argument("--output-logits", action="store_true")
     run_parser.add_argument("--vocab-parallel", action="store_true")
+    run_parser.add_argument("--layer-boundary-markers", action="store_true", default=False)
 
     # Attention
     run_parser.add_argument("--fused-qkv", action="store_true")
     run_parser.add_argument("--sequence-parallel-enabled", action="store_true")
+    run_parser.add_argument("--weight-gather-seq-len-threshold", type=int)
     run_parser.add_argument("--flash-decoding-enabled", action="store_true")
 
     # Continuous batching
@@ -132,6 +139,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     # KV cache
     run_parser.add_argument("--kv-cache-batch-size", type=int)
     run_parser.add_argument("--kv-cache-padding-size", type=int)
+    run_parser.add_argument("--disable-kv-cache-tiling", action="store_true")
 
     # On device sampling
     run_parser.add_argument("--on-device-sampling", action="store_true")
@@ -193,9 +201,16 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
         "This is useful for ensuring processes on different nodes are in sync",
     )
     run_parser.add_argument(
-        "--skip-save-sharded-checkpoint", dest="save_sharded_checkpoint", action="store_false"
+        "--save-sharded-checkpoint",
+        action="store_true",
+        help="Save sharded checkpoints to disk when compiling NxDI model. "
+        "When loading NxDI model, sharded checkpoints will be loaded from the compiled model path.",
+    )
+    run_parser.add_argument(
+        "--skip-sharding",
+        action="store_true",
+        help="Skip sharding checkpoints when compiling NxDI model. "
     )
-    run_parser.add_argument("--skip-sharding", action="store_true")
 
     # PA and CF
     run_parser.add_argument(
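Note on the hunk above: the old `--skip-save-sharded-checkpoint` stored `False` into `save_sharded_checkpoint`, so saving was the default; the new `--save-sharded-checkpoint` makes saving opt-in. A minimal, self-contained sketch of that polarity flip using plain argparse (not NxDI code):

```python
import argparse

# Old behavior: store_false means the attribute defaults to True.
old = argparse.ArgumentParser()
old.add_argument(
    "--skip-save-sharded-checkpoint", dest="save_sharded_checkpoint", action="store_false"
)
assert old.parse_args([]).save_sharded_checkpoint is True  # saved by default

# New behavior: store_true means the attribute defaults to False.
new = argparse.ArgumentParser()
new.add_argument("--save-sharded-checkpoint", action="store_true")
assert new.parse_args([]).save_sharded_checkpoint is False  # saving is now opt-in
```

In other words, invocations that previously relied on the default now need to pass `--save-sharded-checkpoint` explicitly.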
@@ -206,6 +221,9 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument(
         "--enable-chunked-prefill", dest="is_chunked_prefill", action="store_true"
     )
+    run_parser.add_argument(
+        "--enable-prefix-caching", dest="is_prefix_caching", action="store_true"
+    )
     run_parser.add_argument("--cp-max-num-seqs", type=int)
     run_parser.add_argument("--cp-num-active-blocks", type=int)
 
@@ -214,8 +232,8 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
 
     # Lora
     run_parser.add_argument("--enable-lora", action="store_true")
-    run_parser.add_argument("--max-loras", type=int)
-    run_parser.add_argument("--max-lora-rank", type=int)
+    run_parser.add_argument("--max-loras", type=int, default=1)
+    run_parser.add_argument("--max-lora-rank", type=int, default=16)
     run_parser.add_argument("--target-modules", nargs="+")
     run_parser.add_argument("--max-loras-on-cpu", type=int)
     run_parser.add_argument("--lora-ckpt-path", dest="lora_ckpt_paths", type=str, action="append")
@@ -227,10 +245,21 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
227
245
run_parser .add_argument ("--attn-kernel-enabled" , action = "store_true" )
228
246
run_parser .add_argument ("--mlp-kernel-enabled" , action = "store_true" )
229
247
run_parser .add_argument ("--quantized-mlp-kernel-enabled" , action = "store_true" )
230
- run_parser .add_argument ("--activation-quantization-type" , type = str , choices = [e .value for e in ActivationQuantizationType ])
248
+ run_parser .add_argument ("--fused-rmsnorm-skip-gamma" , action = "store_true" )
249
+ run_parser .add_argument (
250
+ "--activation-quantization-type" ,
251
+ type = str ,
252
+ choices = [e .value for e in ActivationQuantizationType ],
253
+ )
231
254
run_parser .add_argument ("--rmsnorm-quantize-kernel-enabled" , action = "store_true" )
232
- run_parser .add_argument ("--quantize-clamp-bound" , type = float , default = float (' inf' ))
255
+ run_parser .add_argument ("--quantize-clamp-bound" , type = float , default = float (" inf" ))
233
256
run_parser .add_argument ("--mlp-kernel-fuse-residual-add" , action = "store_true" )
257
+ run_parser .add_argument ("--qkv-kernel-fuse-residual-add" , action = "store_true" )
258
+ run_parser .add_argument ("--attn-tkg-nki-kernel-enabled" , action = "store_true" )
259
+ run_parser .add_argument ("--attn-tkg-builtin-kernel-enabled" , action = "store_true" )
260
+ run_parser .add_argument ("--attn-block-tkg-nki-kernel-enabled" , action = "store_true" )
261
+ run_parser .add_argument ("--attn-block-tkg-nki-kernel-cache-update" , action = "store_true" )
262
+ run_parser .add_argument ("--k-cache-transposed" , action = "store_true" )
234
263
235
264
# Logical NeuronCore Configuration (LNC)
236
265
lnc_group = run_parser .add_mutually_exclusive_group ()
@@ -246,7 +275,12 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument("--on-cpu", action="store_true")
 
     # Debugging
-    run_parser.add_argument("--capture-indices", nargs="+", type=int, default=None)
+    run_parser.add_argument(
+        "--capture-indices",
+        nargs="+",
+        action=argparse_utils.StringOrIntegers,
+        default=None,
+        help=f"Specify '{argparse_utils.AUTO}' when using check accuracy mode with {CheckAccuracyMode.LOGIT_MATCHING} to infer capture indices when the test fails and use those indices to capture inputs. Otherwise, provide any number of integer values for capturing inputs at those indices.")
     run_parser.add_argument("--input-capture-save-dir", type=str, default=None)
 
     # Optional demo arguments
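`argparse_utils.StringOrIntegers` is introduced by this diff but its implementation is not shown here. A plausible minimal sketch, assuming `argparse_utils.AUTO == "auto"` and that the action accepts either that literal or one or more integers:

```python
import argparse

AUTO = "auto"  # assumed value of argparse_utils.AUTO

class StringOrIntegers(argparse.Action):
    """Accept either the literal AUTO string or one or more integers."""

    def __call__(self, parser, namespace, values, option_string=None):
        # values is a list of strings because the option uses nargs="+"
        if len(values) == 1 and values[0] == AUTO:
            setattr(namespace, self.dest, AUTO)
            return
        try:
            setattr(namespace, self.dest, [int(v) for v in values])
        except ValueError:
            parser.error(f"{option_string} expects '{AUTO}' or a list of integers")
```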
@@ -267,6 +301,11 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
         action="store_true",
         help="Only perform model compilation.",
     )
+    run_parser.add_argument(
+        "--compile-dry-run",
+        action="store_true",
+        help="Perform a compilation dry run (minimal model trace)",
+    )
     run_parser.add_argument(
         "--hlo-debug",
         action="store_true",
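`--compile-dry-run` sits alongside the existing `--compile-only`; both default to off, and either one short-circuits `run_inference` after compilation (see the later hunk). A quick standalone check of the parsing behavior:

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument("--compile-only", action="store_true")
p.add_argument("--compile-dry-run", action="store_true")

args = p.parse_args(["--compile-dry-run"])
assert args.compile_dry_run and not args.compile_only
# run_inference() returns early when either flag is set.
```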
@@ -385,10 +424,12 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     compiling_start_time = time.monotonic()
     if not args.skip_compile and not args.on_cpu:
         print("\nCompiling and saving model...")
-        model.compile(args.compiled_model_path, debug=args.hlo_debug)
+        model.compile(args.compiled_model_path, debug=args.hlo_debug, dry_run=args.compile_dry_run)
         if draft_model is not None and neuron_config.enable_fused_speculation is False:
             print("\nCompiling and saving draft model...")
-            draft_model.compile(args.compiled_draft_model_path)
+            draft_model.compile(
+                args.compiled_draft_model_path, debug=args.hlo_debug, dry_run=args.compile_dry_run
+            )
         compiling_end_time = time.monotonic()
         total_compiling_time = compiling_end_time - compiling_start_time
         print(f"Compiling and tracing time: {total_compiling_time} seconds")
@@ -398,7 +439,7 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     if args.enable_torch_dist:
         torch.distributed.barrier()
 
-    if args.compile_only:
+    if args.compile_only or args.compile_dry_run:
         return
 
     # Load compiled model to Neuron.
@@ -446,25 +487,37 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     if neuron_config.is_medusa:
         draft_model = model
 
+    input_capture_hook = None
+    capture_indices = args.capture_indices
+
     # Check accuracy.
-    run_accuracy_check(
-        model,
-        tokenizer,
-        generation_config,
-        args.prompts[0],
-        args.check_accuracy_mode,
-        args.divergence_difference_tol,
-        args.tol_map,
-        num_tokens_to_check=args.num_tokens_to_check,
-        draft_model=draft_model,
-        expected_outputs_path=args.expected_outputs_path,
-    )
+    logit_error = None
+    try:
+        run_accuracy_check(
+            model,
+            tokenizer,
+            generation_config,
+            args.prompts[0],
+            args.check_accuracy_mode,
+            args.divergence_difference_tol,
+            args.tol_map,
+            num_tokens_to_check=args.num_tokens_to_check,
+            draft_model=draft_model,
+            expected_outputs_path=args.expected_outputs_path,
+        )
+    except LogitMatchingValidationError as e:
+        logit_error = e
+        if args.capture_indices == argparse_utils.AUTO:
+            capture_indices = logit_error.get_divergence_index()
+            print(f"\nAuto capture after a failed logits test. Setting capture indices to {capture_indices}")
 
-    input_capture_hook = None
-    if args.capture_indices:
+    if args.capture_indices == argparse_utils.AUTO and logit_error is None:
+        capture_indices = None
+
+    if capture_indices is not None:
         input_capture_hook = partial(
             capture_model_inputs,
-            capture_indices=args.capture_indices,
+            capture_indices=capture_indices,
             input_capture_save_dir=args.input_capture_save_dir,
         )
 
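The try/except above defers the failure so input capture can still run. It assumes `LogitMatchingValidationError` records where the logits diverged; a minimal sketch of that contract (the real class in `neuronx_distributed_inference.utils.exceptions` may differ):

```python
class LogitMatchingValidationError(Exception):
    """Sketch of the assumed interface, not the actual NxDI implementation."""

    def __init__(self, message: str, divergence_index: int):
        super().__init__(message)
        self._divergence_index = divergence_index

    def get_divergence_index(self) -> int:
        # First token position where logits exceeded the divergence tolerance.
        return self._divergence_index
```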
@@ -479,6 +532,9 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
         input_capture_hook=input_capture_hook,
     )
 
+    if logit_error is not None:
+        raise logit_error
+
     # Benchmarking.
     if args.benchmark:
         benchmark_sampling(model, draft_model, generation_config)
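Taken together, the new pieces give a single-command debugging loop. A workflow sketch (flag names come from this diff; the entry point, mode string, and paths are illustrative):

```python
# 1. Run with auto capture enabled:
#      inference_demo ... \
#          --check-accuracy-mode logit-matching \
#          --capture-indices auto \
#          --input-capture-save-dir ./captured_inputs
# 2. run_accuracy_check() raises LogitMatchingValidationError on divergence;
#    with --capture-indices auto, the handler reads get_divergence_index()
#    and generation re-runs with an input_capture_hook at that index.
# 3. The saved error is re-raised after capture, so the run still fails
#    and the captured inputs can be replayed offline.
```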