The get_ppl missed the last token of each iteration during multi-iter prefill (#2499)

* fix get_ppl

* update

* update

* remove get_ppl from engine.py

* fix according to reviewer comments

* fix

* update

* keep logits.device unchanged

* require input_ids have the same length

* rollback user guide

* update

* split batch dim

* apply torch.cuda.empty_cache()
lvhan028 authored Sep 26, 2024
1 parent bb1dfa6 commit 4812b5a
Showing 4 changed files with 174 additions and 182 deletions.
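For context, the title describes an off-by-one in chunked (multi-iteration) prefill: when a long prompt is split into chunks and each chunk's targets are built only from that chunk (shifted left and padded), the token right after every chunk boundary is never scored. A minimal sketch of the two target layouts, with hypothetical token ids and a chunk size of 3 (illustrative only, not code from this commit):

# Hypothetical 6-token prompt split into prefill chunks of 3.
tokens = [11, 12, 13, 14, 15, 16]
chunk = 3
PAD = -100  # ignored by the loss

# Per-chunk shift (old behaviour): the last position of every chunk gets PAD,
# so token 14 is never used as a prediction target.
old_targets = [tokens[i:i + chunk][1:] + [PAD]
               for i in range(0, len(tokens), chunk)]
# old_targets == [[12, 13, PAD], [15, 16, PAD]]

# Whole-sequence shift, then slice per chunk (fixed behaviour): position 13
# still predicts 14 across the boundary; only the true end of the prompt is PAD.
shifted = tokens[1:] + [PAD]
new_targets = [shifted[i:i + chunk] for i in range(0, len(tokens), chunk)]
# new_targets == [[12, 13, 14], [15, 16, PAD]]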
31 changes: 0 additions & 31 deletions lmdeploy/pytorch/engine/engine.py
@@ -1030,34 +1030,3 @@ async def async_end(self, session_id: int):
def end(self, session_id: int):
"""Add new session."""
return self.engine_instance.end(session_id)

def decode(self,
input_ids,
input_embeddings: List[InputEmbeddingType] = None,
input_embedding_ranges: List[InputEmbeddingRangeType] = None,
steps: List[int] = None,
sequence_start: bool = True,
sequence_end: bool = True,
adapter_names: List[str] = None):
"""Perform context decode on input tokens.
Args:
input_ids (List[List[int]] | List[np.ndaray]): the batch of input
token ids
steps (List[int]): the offset of the k/v cache
input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]):
embeddings features
input_embedding_ranges: (List[List[Tuple[int, int]]]):
the begin/end offsets of input_embeddings to input_ids
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
adapter_names (List[str]): The name of the adapters.
"""
return self.engine_instance.decode(
input_ids,
input_embeddings=input_embeddings,
input_embedding_ranges=input_embedding_ranges,
steps=steps,
sequence_start=sequence_start,
sequence_end=sequence_end,
adapter_names=adapter_names)
229 changes: 163 additions & 66 deletions lmdeploy/serve/utils.py
@@ -81,7 +81,12 @@ def get_logits(
for input_id in input_ids:
assert len(input_id) > 0

max_input_len = self.backend_config.max_prefill_token_num
bs = len(input_ids)
# TODO: a better way to determine `max_input_len`, at most allocate
# 2G mem for logits with shape [bs, max_input_len, vocab_size]
vocab_size = self.hf_tm_cfg.vocab_size
max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)

n_max_iter = np.ceil(
max([len(input_id)
for input_id in input_ids]) / max_input_len).astype(int)
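For a rough sense of what this budget allows (the batch size and vocabulary size below are hypothetical, not values from this repository):

# ~2 GiB of fp32 logits with shape [bs, max_input_len, vocab_size]
bs, vocab_size = 4, 152064
max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)
print(max_input_len)  # 882 tokens per prefill iteration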
@@ -173,79 +178,171 @@ def _split_embeddings(input_ids, niter, iter_len, embeddings,
logits = torch.cat(logits, dim=1)
return logits

def get_ppl(self, input_ids: Union[List[int], List[List[int]]]):
"""Get perplexity scores given a list of input tokens.
def get_ppl(self, input_ids: Union[List[int],
List[List[int]]]) -> List[float]:
"""Get perplexity scores given a list of input tokens that have to be
of the same length.
Args:
input_ids (Union[List[int], List[List[int]]]): the batch of
input token ids
Returns:
Union[float, List[float]]: A list of perplexity scores.
"""
assert len(input_ids) > 0
assert isinstance(input_ids, List)
if isinstance(input_ids[0], int):
input_ids = [input_ids]
for input_id in input_ids:
assert len(input_id) > 1

max_input_len = self.backend_config.max_prefill_token_num
n_max_iter = np.ceil(
max([len(input_id)
for input_id in input_ids]) / max_input_len).astype(int)

index_range_starts = []
index_range_ends = []
for input_id in input_ids:
index_range_start = np.array(
[i * max_input_len for i in range(n_max_iter)])
index_range_end = index_range_start + max_input_len
index_range_start[index_range_start >= len(input_id)] = len(
input_id)
index_range_end[index_range_end >= len(input_id)] = len(input_id)
index_range_starts.append(index_range_start)
index_range_ends.append(index_range_end)

generator = self.engine.create_instance()
all_loss_matrix = []
all_target_mask = []
for i in range(n_max_iter):
steps = [start[i] for start in index_range_starts]
_input_ids = [
input_id[start[i]:end[i]] for input_id, start, end in zip(
input_ids, index_range_starts, index_range_ends)
]
_logits = generator.decode(_input_ids,
steps=steps,
sequence_start=(i == 0),
sequence_end=(i == n_max_iter - 1))
_logits = _logits.float().cpu()
padding_token_id = -100
target_ids = [(x + [padding_token_id])[1:] for x in _input_ids]
            target_ids = [
                torch.Tensor(torch.LongTensor(_target_ids))
                for _target_ids in target_ids
            ]
            target_ids = pad_sequence(target_ids,
                                      batch_first=True,
                                      padding_value=padding_token_id)
            target_ids = target_ids.to(_logits.device)
            target_mask = target_ids != padding_token_id
            target_count = torch.sum(target_mask, dim=-1)
            # compute cross entropy loss
            bsz, seq_len, vocab_size = _logits.shape
            flat_logits = _logits.contiguous().view(-1, vocab_size)
            flat_target_ids = target_ids.contiguous().view(-1)
            flat_loss_matrix = torch.nn.functional.cross_entropy(
                flat_logits,
                flat_target_ids,
                reduction='none',
                ignore_index=padding_token_id)

            all_loss_matrix.append(flat_loss_matrix.view(bsz, seq_len))
            all_target_mask.append(target_mask)

        all_loss_matrix = torch.cat(all_loss_matrix, dim=1)
        all_target_mask = torch.cat(all_target_mask, dim=1)
        target_count = torch.sum(all_target_mask, dim=-1)
        loss_sum = torch.sum(all_loss_matrix * all_target_mask, dim=1)
        loss_avg = loss_sum / target_count
        loss_avg = loss_avg.cpu().numpy()
        return loss_avg

# TODO: a better way to determine `max_input_len`, at most allocate
# 2G mem for logits with shape [bs, max_input_len, vocab_size]
vocab_size = self.hf_tm_cfg.vocab_size
max_input_len = 2 * 1024**3 // (vocab_size * 4)
sizes = [len(_) for _ in input_ids]
losses = []
target_counts = []
sorted_index_values = sorted(list(enumerate(sizes)),
key=lambda x: x[1],
reverse=True)
sizes = [value for index, value in sorted_index_values]
indices = [index for index, value in sorted_index_values]
logger.info(f'sorted sizes: {sizes}')
logger.info(f'sorted indices: {indices}')
for (start, end) in self._batch_iterator(sizes, max_input_len):
logger.info(f'start: {start}, end: {end}')
_input_ids = [input_ids[indices[i]] for i in range(start, end)]
if start == end:
loss, target_count = self._get_long_text_ppl(
generator=generator,
input_ids=_input_ids,
max_input_len=max_input_len)
losses.append(loss)
target_counts.append(target_count)
else:
loss, target_count = self._get_ppl(
generator=generator,
input_ids=_input_ids,
max_input_len=max_input_len,
)
losses.append(loss)
target_counts.append(target_count)
loss = torch.concatenate(losses)
target_count = torch.concatenate(target_counts)
loss_avg = loss / target_count
loss_avg = loss_avg.numpy().tolist()
result = list(range(len(loss_avg)))
for index, sorted_index in enumerate(indices):
result[sorted_index] = loss_avg[index]
return result
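The sort-by-length / scatter-back pattern above can be traced with hypothetical values (the names below are placeholders):

sizes = [5, 9, 7]  # original order of three prompts
order = sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True)  # [1, 2, 0]
loss_avg = ['loss_b', 'loss_c', 'loss_a']  # one value per prompt, in sorted order
result = [None] * len(sizes)
for pos, original_index in enumerate(order):
    result[original_index] = loss_avg[pos]
# result == ['loss_a', 'loss_b', 'loss_c'], i.e. back in the caller's order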

def _batch_iterator(self, sizes, max_value):
"""Return an iterator that calculates intervals (start, end) of a
descend-order list, in which the sum of values in the range is the
maximum number not less than max_value. By "the sum of values",
here it means $$len(sizes[start:end]) * sizes[start]$$
"""
i = 0
while i < len(sizes):
current_sum = 0
start_index = i

while i < len(
sizes) and current_sum + sizes[start_index] <= max_value:
current_sum += sizes[start_index]
i += 1

yield (start_index, i)
if i > start_index:
continue
else:
i += 1
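With the packing rule above (batch cost counted as count x longest-in-batch, i.e. the padded size), a descending list is grouped greedily, and a sequence that alone exceeds the budget comes out as an empty (start, start) interval, which get_ppl routes to the long-text path. A hypothetical trace:

# sizes already sorted in descending order; hypothetical budget of 8192 tokens
sizes, budget = [10000, 4000, 3000, 2000, 1000], 8192
# the generator above would yield:
#   (0, 0)  -> 10000 alone exceeds 8192: long-text path
#   (1, 3)  -> 2 sequences padded to 4000 each = 8000 <= 8192
#   (3, 5)  -> 2 sequences padded to 2000 each = 4000 <= 8192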

def _get_long_text_ppl(self, generator, input_ids, max_input_len):
assert isinstance(input_ids, List) and len(input_ids) == 1
seq_len = len(input_ids[0])
assert seq_len > max_input_len
logger.info(f'get long text ppl: seq_len {seq_len}')

losses = []
target_counts = []
for i in range(0, seq_len, max_input_len):
token_ids = input_ids[:, i:i + max_input_len]
step = [i]
# shift token_ids by 1 to the left
target_ids = input_ids[:, i + 1:i + 1 + max_input_len]

loss, target_count = self._get_ppl(
generator=generator,
input_ids=token_ids,
max_input_len=max_input_len,
target_ids=target_ids,
steps=step,
sequence_start=(i == 0),
sequence_end=(i + max_input_len >= seq_len))
losses.append(loss)
target_counts.append(target_count)
loss_sum = torch.concatenate(losses).sum().unsqueeze(0)
target_count = torch.concatenate(target_counts).sum().unsqueeze(0)
return loss_sum, target_count

def _get_ppl(self,
generator,
input_ids,
max_input_len,
target_ids=None,
steps=None,
sequence_start: bool = True,
sequence_end: bool = True):
assert isinstance(input_ids, List)
assert all(isinstance(_, List) for _ in input_ids)
if target_ids:
assert all(isinstance(_, List) for _ in target_ids)

lens = [len(_) for _ in input_ids]
total_len = sum(lens)
assert sum(lens) <= max_input_len

logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '
f'total_len: {total_len}')
torch.cuda.empty_cache()
logits = generator.decode(input_ids=input_ids,
steps=steps,
sequence_start=sequence_start,
sequence_end=sequence_end)
bsz, seq_len, vocab_size = logits.shape
logits = logits.float()
padding_token_id = -100
if target_ids is None:
# shift token_ids by 1 to the left
target_ids = [x[1:] + [padding_token_id] for x in input_ids]
else:
            target_ids = [
                target_ids[i] + [padding_token_id]
                if len(target_ids[i]) < len(input_ids[i]) else target_ids[i]
                for i in range(bsz)
            ]
target_ids = [
torch.Tensor(torch.LongTensor(_target_ids))
for _target_ids in target_ids
]
target_ids = pad_sequence(target_ids,
batch_first=True,
padding_value=padding_token_id)
target_ids = target_ids.to(logits.device)
target_mask = target_ids != padding_token_id

# compute cross entropy loss
flat_logits = logits.contiguous().view(-1, vocab_size)
flat_target_ids = target_ids.contiguous().view(-1)
flat_loss_matrix = torch.nn.functional.cross_entropy(
flat_logits,
flat_target_ids,
reduction='none',
ignore_index=padding_token_id)
flat_loss_matrix = flat_loss_matrix.view(bsz, seq_len)
loss = flat_loss_matrix.sum(dim=-1).cpu()
target_count = target_mask.sum(dim=-1).cpu()
return loss, target_count
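The public get_ppl above returns one value per input sequence: the summed token cross-entropy divided by the number of scored targets, i.e. a mean negative log-likelihood. If a conventional (exponentiated) perplexity is wanted, it is one step away (the values below are made up):

import math

mean_losses = [2.31, 3.07]  # hypothetical per-sequence outputs of get_ppl
perplexities = [math.exp(x) for x in mean_losses]
# -> approximately [10.1, 21.5]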
4 changes: 4 additions & 0 deletions lmdeploy/turbomind/deploy/config.py
@@ -135,5 +135,9 @@ def weight_type(self):
def group_size(self):
return self.model_config.group_size

@property
def vocab_size(self):
return self.model_config.vocab_size

def __str__(self):
return json.dumps(self.to_dict(), indent=2)
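A rough end-to-end sketch of how the changed method is reached through the lmdeploy pipeline API; the model id and token ids below are placeholders, and details may differ across lmdeploy versions:

from lmdeploy import pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat')  # hypothetical model id

# get_ppl expects a batch of token-id lists (tokenized elsewhere) and returns
# one mean token loss per sequence, in the original input order.
token_ids = [
    [1, 3344, 522, 918, 2751],  # made-up ids, for illustration only
    [1, 3344, 522, 742, 6380],
]
scores = pipe.get_ppl(token_ids)
print(scores)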