
Commit 9b90cd0

Merge pull request #18 from aws-neuron/release_2.23.0
Neuron Release 2.23.0
2 parents bceee12 + 5de79de commit 9b90cd0


50 files changed: +4202 −1016 lines

examples/multi_node.md

Lines changed: 55 additions & 6 deletions
@@ -1,9 +1,58 @@
-## Example to run two process with tp=16 on single Trn1 node.
+## Using the NxDI distributed launcher
 
+If your environment has ``MPI`` installed, you can use the distributed launcher utility provided in the NxDI library.
+
+Use ``--`` to separate the distributed launcher arguments from the program command. Specify hostnames/IP addresses through the ``--hosts`` argument.
+
+```bash
+nxdi_distributed_launcher \
+    --hosts ip-172-31-35-58 ip-172-31-34-25 \
+    --nproc_per_node 1 \
+    -- \
+    inference_demo \
+    --model-type llama \
+    --task-type causal-lm \
+    run \
+    --model-path TinyLLama-v0 \
+    --compiled-model-path traced_models/TinyLLama-v0-multi-node-0/ \
+    --enable-torch-dist \
+    --local_ranks_size 32 \
+    --tp-degree 64 \
+    --batch-size 2 \
+    --max-context-length 32 \
+    --seq-len 64 \
+    --on-device-sampling \
+    --enable-bucketing \
+    --top-k 1 \
+    --do-sample \
+    --pad-token-id 2 \
+    --prompt "I believe the meaning of life is" \
+    --prompt "The color of the sky is" 2>&1 |& tee log
+```
+
+To run a ``torchrun`` command with the launcher, pass the ``--torchrun`` argument to the launcher. For example:
 
-Process 1:
 ```bash
+nxdi_distributed_launcher \
+    --hosts ip-172-31-35-58 ip-172-31-34-25 \
+    --nproc_per_node 1 \
+    --torchrun \
+    -- \
+    example.py
+```
+
+
+## Manually running processes on each node
+
+You can run one process on each node to execute an NxDI model on multiple nodes. To distinguish ranks,
+provide the ``--start_rank_id <int>`` argument. Ensure that ``--start_rank_id <int>``
+is a multiple of the ``--local_ranks_size <int>`` argument.
 
+
+### Example to run two processes with tp=16 on a single Trn1 node
+
+Process 1:
+```bash
 MASTER_PORT=65111 NEURON_RT_VISIBLE_CORES=0-15 NEURON_CPP_LOG_LEVEL=1 NEURON_RT_ROOT_COMM_ID=10.1.201.64:63423 inference_demo \
 --model-type llama \
 --task-type causal-lm \
@@ -28,7 +77,7 @@ MASTER_PORT=65111 NEURON_RT_VISIBLE_CORES=0-15 NEURON_CPP_LOG_LEVEL=1 NEURON_R
 
 
 Process 2:
-```
+```bash
 NEURON_RT_VISIBLE_CORES=16-31 NEURON_CPP_LOG_LEVEL=1 NEURON_RT_ROOT_COMM_ID=10.1.201.64:63423 inference_demo \
 --model-type llama \
 --task-type causal-lm \
@@ -54,9 +103,9 @@ NEURON_RT_VISIBLE_CORES=16-31 NEURON_CPP_LOG_LEVEL=1 NEURON_RT_ROOT_COMM_ID=10.1
 ```
 
 
-## Example to run two process with tp=64 on two Trn1 nodes.
+### Example to run two processes with tp=64 on two Trn1 nodes
 
-```
+```bash
 NEURON_CPP_LOG_LEVEL=1 NEURON_RT_ROOT_COMM_ID=10.1.201.64:63423 inference_demo \
 --model-type llama \
 --task-type causal-lm \
@@ -79,7 +128,7 @@ NEURON_CPP_LOG_LEVEL=1 NEURON_RT_ROOT_COMM_ID=10.1.201.64:63423 inference_demo \
 --prompt "The color of the sky is" 2>&1 | tee rank_0.log
 ```
 
-```
+```bash
 NEURON_CPP_LOG_LEVEL=1 NEURON_RT_ROOT_COMM_ID=10.1.201.64:63423 inference_demo \
 --model-type llama \
 --task-type causal-lm \
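
The manual-launch examples above hinge on the rank bookkeeping described in the new section: each node's process owns ``--local_ranks_size`` consecutive global ranks starting at ``--start_rank_id``. A small illustrative Python sketch of that rule (not part of this commit; the helper name ``node_rank_range`` is hypothetical, and the assumption that ``--tp-degree`` divides evenly into local rank groups is inferred from the tp=64/size=32 example):

```python
def node_rank_range(start_rank_id: int, local_ranks_size: int, tp_degree: int) -> range:
    """Global ranks covered by one node's process, per the rules quoted above."""
    if start_rank_id % local_ranks_size != 0:
        raise ValueError("--start_rank_id must be a multiple of --local_ranks_size")
    if tp_degree % local_ranks_size != 0:  # assumed, based on the tp=64/size=32 example
        raise ValueError("--tp-degree must divide evenly into local rank groups")
    return range(start_rank_id, start_rank_id + local_ranks_size)

# Two Trn1 nodes with tp=64 and 32 ranks per node:
print(node_rank_range(0, 32, 64))   # range(0, 32)  -> first node (writes rank_0.log)
print(node_rank_range(32, 32, 64))  # range(32, 64) -> second node
```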

setup.py

Lines changed: 17 additions & 1 deletion
@@ -1,9 +1,25 @@
 from setuptools import PEP420PackageFinder, setup
+import os
+import subprocess
+from subprocess import CalledProcessError
+
+
+def get_version(version_str):
+    major, minor, patch = version_str.split(".")
+    patch = os.getenv('VERSION_PATCH', patch)
+    suffix = os.getenv('SUFFIX')
+    if not suffix:
+        try:
+            suffix = f'{subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()[0:8]}.dev'
+        except CalledProcessError:
+            suffix = 'dev'
+    return f"{major}.{minor}.{patch}+{suffix}"
+
 
 exec(open("src/neuronx_distributed_inference/_version.py").read())
 setup(
     name="neuronx-distributed-inference",
-    version=__version__,  # noqa F821
+    version=get_version(__version__),  # noqa F821
     classifiers=[
         "Development Status :: 3 - Alpha",
         "Intended Audience :: Developers",
src/neuronx_distributed_inference/_version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
 # ==============================================================================
-__version__ = "0.2.0"
+__version__ = "0.3.0"

src/neuronx_distributed_inference/inference_demo.py

Lines changed: 82 additions & 26 deletions
@@ -10,7 +10,10 @@
 from typing import Type
 
 import torch
-from neuronx_distributed.quantization.quantization_config import QuantizationType, ActivationQuantizationType
+from neuronx_distributed.quantization.quantization_config import (
+    ActivationQuantizationType,
+    QuantizationType,
+)
 from transformers import AutoTokenizer, GenerationConfig
 
 from neuronx_distributed_inference.models.application_base import NeuronApplicationBase
@@ -28,9 +31,11 @@
     check_accuracy_logits,
     get_generate_outputs,
 )
+from neuronx_distributed_inference.utils import argparse_utils
 from neuronx_distributed_inference.utils.benchmark import benchmark_sampling
 from neuronx_distributed_inference.utils.debug_utils import capture_model_inputs
 from neuronx_distributed_inference.utils.distributed import get_init_rank, get_init_world_size
+from neuronx_distributed_inference.utils.exceptions import LogitMatchingValidationError
 from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config
 from neuronx_distributed_inference.utils.random import set_random_seed
 
@@ -117,10 +122,12 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument("--rpl-reduce-dtype", type=to_torch_dtype)
     run_parser.add_argument("--output-logits", action="store_true")
     run_parser.add_argument("--vocab-parallel", action="store_true")
+    run_parser.add_argument("--layer-boundary-markers", action="store_true", default=False)
 
     # Attention
     run_parser.add_argument("--fused-qkv", action="store_true")
     run_parser.add_argument("--sequence-parallel-enabled", action="store_true")
+    run_parser.add_argument("--weight-gather-seq-len-threshold", type=int)
     run_parser.add_argument("--flash-decoding-enabled", action="store_true")
 
     # Continuous batching
@@ -132,6 +139,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     # KV cache
     run_parser.add_argument("--kv-cache-batch-size", type=int)
     run_parser.add_argument("--kv-cache-padding-size", type=int)
+    run_parser.add_argument("--disable-kv-cache-tiling", action="store_true")
 
     # On device sampling
     run_parser.add_argument("--on-device-sampling", action="store_true")
@@ -193,9 +201,16 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
         "This is useful for ensuring processes on different nodes are in sync",
     )
     run_parser.add_argument(
-        "--skip-save-sharded-checkpoint", dest="save_sharded_checkpoint", action="store_false"
+        "--save-sharded-checkpoint",
+        action="store_true",
+        help="Save sharded checkpoints to disk when compiling NxDI model. "
+        "When loading NxDI model, sharded checkpoints will be loaded from the compiled model path.",
+    )
+    run_parser.add_argument(
+        "--skip-sharding",
+        action="store_true",
+        help="Skip sharding checkpoints when compiling NxDI model. "
     )
-    run_parser.add_argument("--skip-sharding", action="store_true")
 
     # PA and CF
     run_parser.add_argument(
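
Worth noting: the hunk above flips the sharded-checkpoint default. The removed ``--skip-save-sharded-checkpoint`` stored ``False`` into ``save_sharded_checkpoint``, so saving was on unless explicitly skipped; the new ``--save-sharded-checkpoint`` makes saving opt-in. A minimal standalone sketch of the argparse difference (illustration only, detached from the surrounding parser):

```python
import argparse

old = argparse.ArgumentParser()
old.add_argument("--skip-save-sharded-checkpoint", dest="save_sharded_checkpoint", action="store_false")
print(old.parse_args([]).save_sharded_checkpoint)  # True: saving was the default

new = argparse.ArgumentParser()
new.add_argument("--save-sharded-checkpoint", action="store_true")
print(new.parse_args([]).save_sharded_checkpoint)  # False: saving is now opt-in
```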
@@ -206,6 +221,9 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument(
         "--enable-chunked-prefill", dest="is_chunked_prefill", action="store_true"
     )
+    run_parser.add_argument(
+        "--enable-prefix-caching", dest="is_prefix_caching", action="store_true"
+    )
     run_parser.add_argument("--cp-max-num-seqs", type=int)
     run_parser.add_argument("--cp-num-active-blocks", type=int)
 
@@ -214,8 +232,8 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
 
     # Lora
     run_parser.add_argument("--enable-lora", action="store_true")
-    run_parser.add_argument("--max-loras", type=int)
-    run_parser.add_argument("--max-lora-rank", type=int)
+    run_parser.add_argument("--max-loras", type=int, default=1)
+    run_parser.add_argument("--max-lora-rank", type=int, default=16)
     run_parser.add_argument("--target-modules", nargs="+")
     run_parser.add_argument("--max-loras-on-cpu", type=int)
     run_parser.add_argument("--lora-ckpt-path", dest="lora_ckpt_paths", type=str, action="append")
@@ -227,10 +245,21 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument("--attn-kernel-enabled", action="store_true")
     run_parser.add_argument("--mlp-kernel-enabled", action="store_true")
     run_parser.add_argument("--quantized-mlp-kernel-enabled", action="store_true")
-    run_parser.add_argument("--activation-quantization-type", type=str, choices=[e.value for e in ActivationQuantizationType])
+    run_parser.add_argument("--fused-rmsnorm-skip-gamma", action="store_true")
+    run_parser.add_argument(
+        "--activation-quantization-type",
+        type=str,
+        choices=[e.value for e in ActivationQuantizationType],
+    )
     run_parser.add_argument("--rmsnorm-quantize-kernel-enabled", action="store_true")
-    run_parser.add_argument("--quantize-clamp-bound", type=float, default=float('inf'))
+    run_parser.add_argument("--quantize-clamp-bound", type=float, default=float("inf"))
     run_parser.add_argument("--mlp-kernel-fuse-residual-add", action="store_true")
+    run_parser.add_argument("--qkv-kernel-fuse-residual-add", action="store_true")
+    run_parser.add_argument("--attn-tkg-nki-kernel-enabled", action="store_true")
+    run_parser.add_argument("--attn-tkg-builtin-kernel-enabled", action="store_true")
+    run_parser.add_argument("--attn-block-tkg-nki-kernel-enabled", action="store_true")
+    run_parser.add_argument("--attn-block-tkg-nki-kernel-cache-update", action="store_true")
+    run_parser.add_argument("--k-cache-transposed", action="store_true")
 
     # Logical NeuronCore Configuration (LNC)
     lnc_group = run_parser.add_mutually_exclusive_group()
@@ -246,7 +275,12 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument("--on-cpu", action="store_true")
 
     # Debugging
-    run_parser.add_argument("--capture-indices", nargs="+", type=int, default=None)
+    run_parser.add_argument(
+        "--capture-indices",
+        nargs="+",
+        action=argparse_utils.StringOrIntegers,
+        default=None,
+        help=f"Specify '{argparse_utils.AUTO}' when using check accuracy mode with {CheckAccuracyMode.LOGIT_MATCHING} to infer capture indices when the test fails and use those indices to capture inputs. Otherwise, provide any number of integer values for capturing inputs at those indices.")
     run_parser.add_argument("--input-capture-save-dir", type=str, default=None)
 
     # Optional demo arguments
@@ -267,6 +301,11 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
         action="store_true",
         help="Only perform model compilation.",
     )
+    run_parser.add_argument(
+        "--compile-dry-run",
+        action="store_true",
+        help="Perform a compilation dry run (minimal model trace)",
+    )
     run_parser.add_argument(
         "--hlo-debug",
         action="store_true",
@@ -385,10 +424,12 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     compiling_start_time = time.monotonic()
     if not args.skip_compile and not args.on_cpu:
         print("\nCompiling and saving model...")
-        model.compile(args.compiled_model_path, debug=args.hlo_debug)
+        model.compile(args.compiled_model_path, debug=args.hlo_debug, dry_run=args.compile_dry_run)
         if draft_model is not None and neuron_config.enable_fused_speculation is False:
             print("\nCompiling and saving draft model...")
-            draft_model.compile(args.compiled_draft_model_path)
+            draft_model.compile(
+                args.compiled_draft_model_path, debug=args.hlo_debug, dry_run=args.compile_dry_run
+            )
     compiling_end_time = time.monotonic()
     total_compiling_time = compiling_end_time - compiling_start_time
     print(f"Compiling and tracing time: {total_compiling_time} seconds")
@@ -398,7 +439,7 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     if args.enable_torch_dist:
         torch.distributed.barrier()
 
-    if args.compile_only:
+    if args.compile_only or args.compile_dry_run:
         return
 
     # Load compiled model to Neuron.
@@ -446,25 +487,37 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     if neuron_config.is_medusa:
         draft_model = model
 
+    input_capture_hook = None
+    capture_indices = args.capture_indices
+
     # Check accuracy.
-    run_accuracy_check(
-        model,
-        tokenizer,
-        generation_config,
-        args.prompts[0],
-        args.check_accuracy_mode,
-        args.divergence_difference_tol,
-        args.tol_map,
-        num_tokens_to_check=args.num_tokens_to_check,
-        draft_model=draft_model,
-        expected_outputs_path=args.expected_outputs_path,
-    )
+    logit_error = None
+    try:
+        run_accuracy_check(
+            model,
+            tokenizer,
+            generation_config,
+            args.prompts[0],
+            args.check_accuracy_mode,
+            args.divergence_difference_tol,
+            args.tol_map,
+            num_tokens_to_check=args.num_tokens_to_check,
+            draft_model=draft_model,
+            expected_outputs_path=args.expected_outputs_path,
+        )
+    except LogitMatchingValidationError as e:
+        logit_error = e
+        if args.capture_indices == argparse_utils.AUTO:
+            capture_indices = logit_error.get_divergence_index()
+            print(f"\nAuto capture after a failed logits test. Setting capture indices to {capture_indices}")
 
-    input_capture_hook = None
-    if args.capture_indices:
+    if args.capture_indices == argparse_utils.AUTO and logit_error is None:
+        capture_indices = None
+
+    if capture_indices is not None:
         input_capture_hook = partial(
             capture_model_inputs,
-            capture_indices=args.capture_indices,
+            capture_indices=capture_indices,
             input_capture_save_dir=args.input_capture_save_dir,
         )
 
@@ -479,6 +532,9 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
         input_capture_hook=input_capture_hook,
     )
 
+    if logit_error is not None:
+        raise logit_error
+
     # Benchmarking.
     if args.benchmark:
         benchmark_sampling(model, draft_model, generation_config)
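
The new ``--capture-indices`` handling accepts either the ``argparse_utils.AUTO`` sentinel or a list of integers via ``argparse_utils.StringOrIntegers``, neither of which is shown in this diff. As a rough mental model only (an assumption, not the module's actual code, and ``AUTO = "auto"`` is a guess), such an argparse action could look like:

```python
import argparse

AUTO = "auto"  # assumed sentinel value; the real one lives in argparse_utils


class StringOrIntegers(argparse.Action):
    """Accept either the literal AUTO string or a list of integers (assumed behavior)."""

    def __call__(self, parser, namespace, values, option_string=None):
        if len(values) == 1 and values[0] == AUTO:
            setattr(namespace, self.dest, AUTO)
            return
        try:
            setattr(namespace, self.dest, [int(v) for v in values])
        except ValueError:
            parser.error(f"{option_string} expects '{AUTO}' or a list of integers")


parser = argparse.ArgumentParser()
parser.add_argument("--capture-indices", nargs="+", action=StringOrIntegers, default=None)
print(parser.parse_args(["--capture-indices", "auto"]).capture_indices)    # 'auto'
print(parser.parse_args(["--capture-indices", "3", "7"]).capture_indices)  # [3, 7]
```

This matches how ``run_inference`` uses the flag above: ``AUTO`` defers index selection until a ``LogitMatchingValidationError`` reports a divergence index, while explicit integers capture inputs unconditionally.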
