
Commit 6eef2a6

Merge pull request #6 from aws-neuron/release_2.21.1
Neuron 2.21.1 release
2 parents: a850e4c + 6d1b8ca

4 files changed: +82 -8 lines changed
Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 # Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
 # ==============================================================================
-__version__ = "0.1.0"
+__version__ = "0.1.1"

src/neuronx_distributed_inference/models/model_base.py

Lines changed: 4 additions & 0 deletions

@@ -38,6 +38,7 @@
     Sampler,
     prepare_sampling_params,
     rand_like,
+    validate_sampling_params,
 )
 from neuronx_distributed_inference.modules.kvcache.kv_cache_manager import (
     KVCacheManager,
@@ -1358,6 +1359,9 @@ def forward(
         sampling_params = (
             self.default_sampling_params if sampling_params is None else sampling_params
         )
+        if self.on_device_sampling:
+            validate_sampling_params(sampling_params, self.neuron_config.on_device_sampling_config)
+
         self.sampling_params = sampling_params
 
         output_attentions, output_hidden_states, return_dict = self._setup_func_config(
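
For orientation, here is a minimal sketch of how this new guard is exercised end to end. It assumes the package is installed, that prepare_sampling_params (whose full body is not shown in this diff) assembles the (batch_size, 3) tensor the validator expects, and it uses a plain dict and a local flag as stand-ins for neuron_config.on_device_sampling_config and self.on_device_sampling:

import torch

from neuronx_distributed_inference.modules.generation.sampling import (
    prepare_sampling_params,
    validate_sampling_params,
)

# Stand-in for neuron_config.on_device_sampling_config: the validator also
# accepts a plain dict with a "global_topk" key (see sampling.py below).
on_device_sampling_config = {"global_topk": 256}

# Build the (batch_size, 3) params tensor: one [top_k, top_p, temperature]
# row per batch line. A top_k of -1 is explicitly allowed by the validator.
sampling_params = prepare_sampling_params(
    batch_size=2, top_k=[50, -1], top_p=[0.9, 1.0], temperature=[0.7, 1.0]
)

# Mirrors the new guard in forward(): validate only when sampling runs on device.
on_device_sampling = True  # stand-in for self.on_device_sampling
if on_device_sampling:
    validate_sampling_params(sampling_params, on_device_sampling_config)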

src/neuronx_distributed_inference/models/model_wrapper.py

Lines changed: 19 additions & 5 deletions

@@ -300,26 +300,36 @@ def get_model_instance(self):
 
     def _forward_with_pad(self, *args):
         seq_ids = args[3]
-        if len(args) > 4:
-            medusa_args = args[4:8]
+        sampling_params = args[4]
+        if len(args) > 5:
+            medusa_args = args[5:8]
         else:
             medusa_args = None
 
         # pad the inputs up to the compiled batch size in the end
-        def pad_helper(tensor):
+        def pad_helper(tensor, pad_type="zeros"):
+            VALID_PAD_TYPES = set(["zeros", "ones", "repeat_first_batchline"])
+            assert (
+                pad_type in VALID_PAD_TYPES
+            ), f"Found {pad_type=}, but valid pad types are {VALID_PAD_TYPES}"
             if tensor is None or tensor.shape[0] == self.neuron_config.batch_size:
                 return tensor
 
             padded_shape = list(tensor.shape)
             padded_shape[0] = self.neuron_config.batch_size
-            padded_tensor = torch.zeros(padded_shape, dtype=tensor.dtype)
+            if pad_type == "repeat_first_batchline":
+                # pad with first batch line values instead of zeros, to reduce chances of NaN
+                padded_tensor = tensor[0].unsqueeze(0).repeat(padded_shape[0], 1).to(tensor.dtype)
+            else:
+                fill_value = 0 if pad_type == "zeros" else 1
+                padded_tensor = torch.full(padded_shape, fill_value=fill_value, dtype=tensor.dtype)
             padded_tensor[: tensor.shape[0]] = tensor
             return padded_tensor
 
         padded_args = []
         # pad input_ids, attn_mask and position_ids
         for arg in args[0:3]:
-            padded_args.append(pad_helper(arg))
+            padded_args.append(pad_helper(arg, pad_type="repeat_first_batchline"))
 
         # need to handle seq_ids separately, when compiled batch is 4, if we pad seq_ids from [0,2,1] to [0,2,1,
         # 0]. then the kv cache of padded input could be written into the first cache line, so we need to pad as [0,
@@ -333,6 +343,10 @@ def pad_helper(tensor):
         )
         padded_args.append(padded_seq_ids)
 
+        # pad sampling params by repeating first batchline
+        padded_sampling_params = pad_helper(sampling_params, pad_type="repeat_first_batchline")
+        padded_args.append(padded_sampling_params)
+
         if medusa_args is not None:
             for arg in medusa_args:
                 padded_args.append(pad_helper(arg))
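
The padding change is easiest to see in isolation. The sketch below re-creates the inner helper as a hypothetical standalone function pad_to_compiled_batch (compiled_batch_size stands in for self.neuron_config.batch_size) to show what repeat_first_batchline padding produces for a 2-D input:

import torch


def pad_to_compiled_batch(tensor, compiled_batch_size, pad_type="zeros"):
    # Hypothetical standalone version of the inner pad_helper above.
    assert pad_type in {"zeros", "ones", "repeat_first_batchline"}
    if tensor is None or tensor.shape[0] == compiled_batch_size:
        return tensor
    padded_shape = list(tensor.shape)
    padded_shape[0] = compiled_batch_size
    if pad_type == "repeat_first_batchline":
        # Repeat the first batch line instead of writing zeros; per the
        # in-code comment, this reduces the chance of NaNs in padded rows.
        padded = tensor[0].unsqueeze(0).repeat(padded_shape[0], 1).to(tensor.dtype)
    else:
        fill_value = 0 if pad_type == "zeros" else 1
        padded = torch.full(padded_shape, fill_value=fill_value, dtype=tensor.dtype)
    padded[: tensor.shape[0]] = tensor  # restore the real batch lines at the front
    return padded


input_ids = torch.tensor([[11, 12], [21, 22], [31, 32]])  # runtime batch of 3
padded = pad_to_compiled_batch(input_ids, compiled_batch_size=4, pad_type="repeat_first_batchline")
print(padded)
# tensor([[11, 12],
#         [21, 22],
#         [31, 32],
#         [11, 12]])  <- padded row repeats batch line 0 rather than zeros

Padding with real token values keeps the padded batch lines inside the value range the rest of the graph expects, which matters for inputs like sampling params where a zero temperature would otherwise be invalid.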

src/neuronx_distributed_inference/modules/generation/sampling.py

Lines changed: 58 additions & 2 deletions

@@ -1,12 +1,12 @@
-from typing import Union
+from typing import Any, Dict, Union
 
 import torch
 from neuronx_distributed.operators.argmax import argmax as nxd_argmax
 from neuronx_distributed.operators.topk import topk as nxd_topk
 from neuronx_distributed.parallel_layers import parallel_state
 from torch_neuronx.xla_impl.ops import xla_hlo_call
 
-from neuronx_distributed_inference.models.config import NeuronConfig
+from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig
 
 
 @xla_hlo_call
@@ -18,6 +18,62 @@ def rand_like(tensor):
     return dtype[shape].Rng(minimum, maximum, distribution=1)  # Uniform distribution
 
 
+def validate_sampling_params(
+    params: torch.Tensor, on_device_sampling_config: Union[Dict[str, Any], OnDeviceSamplingConfig]
+) -> None:
+    """
+    Validates sampling parameters for language models.
+
+    Args:
+        params (torch.Tensor): Tensor of shape (batch_size, 3) containing sampling parameters
+            in the order: top-k, top-p, temperature.
+        on_device_sampling_config: OnDeviceSamplingConfig or dict providing the global_topk limit.
+
+    Raises:
+        ValueError: If any of the parameters are invalid.
+    """
+    if params.shape[1] != 3:
+        raise ValueError(f"Expected tensor of shape (batch_size, 3), but got {params.shape}")
+
+    # autocast params tensor to float32
+    params = params.to(torch.float32)
+
+    # Unpack parameters
+    top_k, top_p, temperature = params[:, 0], params[:, 1], params[:, 2]
+
+    if isinstance(on_device_sampling_config, OnDeviceSamplingConfig):
+        global_top_k = on_device_sampling_config.global_topk
+    else:
+        global_top_k = on_device_sampling_config["global_topk"]
+
+    # Validate top-k value range
+    valid_top_k = (top_k == -1) | ((top_k > 0) & (top_k <= global_top_k))
+    if not torch.all(valid_top_k):
+        raise ValueError(
+            f"Invalid top-k values found. top-k must be -1 or greater than 0 but less than or equal to {global_top_k=}. Found {top_k=}."
+        )
+
+    # Check that top-k values can be represented as integers
+    if not torch.equal(top_k, top_k.floor()):
+        raise ValueError(
+            f"Invalid top-k values found. top-k values should be able to be represented as integer values, but found decimal parts. Found {top_k=}."
+        )
+
+    # Validate top-p
+    valid_top_p = (top_p > 0.0) & (top_p <= 1.0)
+    if not torch.all(valid_top_p):
+        raise ValueError(
+            f"Invalid top-p values found. top-p must be in the range (0.0, 1.0]. Found {top_p=}."
+        )
+
+    # Validate temperature
+    valid_temp = temperature > 0.0
+    if not torch.all(valid_temp):
+        raise ValueError(
+            f"Invalid temperature values found. Temperature must be strictly greater than 0.0. Found {temperature=}."
+        )
+
+
 def prepare_sampling_params(batch_size, top_k=[1], top_p=[1.0], temperature=[1.0]):
     top_k = prepare_tensor(top_k)
     top_p = prepare_tensor(top_p)
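
A short usage sketch of the new validator, assuming the package is importable; it exercises only behavior visible in the diff, including the dict form of the config with a "global_topk" key:

import torch

from neuronx_distributed_inference.modules.generation.sampling import validate_sampling_params

config = {"global_topk": 256}  # dict form accepted alongside OnDeviceSamplingConfig

# Valid rows: top-k is -1 or in (0, global_topk], top-p in (0.0, 1.0], temperature > 0.
ok = torch.tensor([[50.0, 0.9, 0.7], [-1.0, 1.0, 1.0]])
validate_sampling_params(ok, config)  # returns None, no error raised

# Invalid row: top-p of 0.0 falls outside (0.0, 1.0].
bad = torch.tensor([[50.0, 0.0, 0.7]])
try:
    validate_sampling_params(bad, config)
except ValueError as err:
    print(err)  # "Invalid top-p values found. top-p must be in the range (0.0, 1.0]. ..."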
