fix(generation): correct generation for batch size > 1

The sampling function passed to Jetstream's generate method is JIT'ed.
This made generation incorrect for batch size > 1, because the compiled
function did not detect that the slot objects had been modified, and thus
produced wrong results.
The solution is to use `jax.pure_callback`, which goes back to Python and
reads the objects' current values.
At that point another issue appeared: the sampling functions did not work
in this context and got stuck forever on the call.
The workaround is to force JIT on this last part.

tengomucho committed Nov 4, 2024
1 parent 8d1c3df commit 542a9bc
Showing 3 changed files with 195 additions and 58 deletions.
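
To illustrate the first part of the fix described in the commit message, here is a minimal, self-contained sketch of the `jax.pure_callback` pattern. It is not the repository's code: `FakeSlot` and `pick_token` are hypothetical names, and the greedy selection merely stands in for the per-slot logic.

```python
import numpy as np
import jax
import jax.numpy as jnp


class FakeSlot:
    """Mutable Python object whose state a traced (JIT'ed) function would not track."""

    def select(self, logits: np.ndarray) -> np.ndarray:
        # Host-side greedy selection; runs in plain Python.
        return np.argmax(logits, axis=-1).astype(np.int32)


slot = FakeSlot()


@jax.jit
def pick_token(logits: jnp.ndarray) -> jnp.ndarray:
    def _host_select(logits):
        # Executed back in Python on every call, so it reads the current state
        # of `slot` instead of whatever was captured at trace time.
        return slot.select(logits)

    # pure_callback escapes the traced computation; the result shape and dtype
    # must be declared so JAX can stitch the value back into the graph.
    return jax.pure_callback(
        _host_select,
        jax.ShapeDtypeStruct((1,), jnp.int32),
        logits,
    )


print(pick_token(jnp.array([[0.1, 2.0, 0.3]])))  # -> [1]
```

Note that `jax.pure_callback` assumes the callback is functionally pure; relying on mutable slot state is exactly the non-functional pattern that the NOTE inside the diff flags as a possible performance concern to revisit.
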

@@ -241,7 +241,13 @@ def set(self, slot: Slot):
        self._curslot = slot

    def select(self, logits: jnp.ndarray) -> int:
-        return self._curslot.select(logits)
+        def _inner_select(logits):
+            return self._curslot.select(logits)
+        token = jax.pure_callback(
+            _inner_select,
+            result_shape_dtypes=jax.ShapeDtypeStruct((), jnp.int32),
+            logits=logits)
+        return token

class TpuGeneratorJetStream(Generator):
    """A Generator for models running on TPU, single threaded."""

@@ -266,9 +272,7 @@ def __init__(
        tokenizer.truncation_side = "left"
        self.tokenizer = tokenizer
        self.special_tokens = self.tokenizer.all_special_ids
-
-        # Slots are empty to begin with, they will be populated as new batches arrive
-        self.slots = []
+        self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
        self.batch_id = 0
        # Note: this index will _never_ be decremented, and that's fine.
        self.slot_index = 0

@@ -366,13 +370,11 @@ def warmup(self, batch: Batch) -> int:
        seq_len = self.engine.env.seq_len
        return batch_size * seq_len

-    def _get_slot_id(self):
-        """Get the next available slot id."""
-        batch_size = self.engine.env.batch_size
-        used_ids = [slot.id for slot in self.slots if slot.state != Slot.State.EMPTY]
-        for i in range(batch_size):
-            if i not in used_ids:
-                return i
+    def _get_slot(self):
+        """Get the next available slot."""
+        for slot in self.slots:
+            if slot.state == Slot.State.EMPTY:
+                return slot
        # if we reach this point, all slots were used - this should not happen
        raise ValueError("All slots are used, but we should have stopped earlier")


@@ -422,14 +424,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """

-        slots = {state: [] for state in Slot.State}
-        for slot in self.slots:
-            slots[slot.state].append(slot)
-        len_active_slots = len(slots[Slot.State.READY])
-        # Delete all empty slots, no need to have them anymore
-        empty_slots = slots[Slot.State.EMPTY]
-        for slot in empty_slots:
-            self.slots.remove(slot)
+        active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
+        len_active_slots = len(active_slots)
+
        len_requests = len(batch.requests)
        model_batch_size = self.model.config.batch_size
        if model_batch_size is not None and model_batch_size < len_active_slots + len_requests:

@@ -444,10 +441,10 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        # Assign each request to an empty slot
        logger.debug(f"Prefilling {len_requests} new request(s) adding to {len_active_slots} active slot(s)")
        generations = []

+        prefilled_active_slots = []
        for request in batch.requests:
-            # Dynamically create a new slot for each request
-            slot = Slot(self._get_slot_id(), self.tokenizer)
+            slot = self._get_slot()
            self.prefill_slot.set(slot)
            self.slot_index += 1
            slot.assign(self.batch_id, request, self.model.generation_config)

@@ -479,29 +476,34 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:

            self._post_generate(slot, next_token, generations)
            if not slot.empty:
-                # append current to list of active slots
-                self.slots.append(slot)
-                len_active_slots += 1
+                prefilled_active_slots.append(slot)

-        batch = None
-        if len_active_slots > 0:
-            # Whatever initial batch these requests came from, we always return all pending requests in a single batch
-            request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
-            batch = self._cached_batch(self.batch_id, request_ids)
-        else:
-            logger.debug("No more pending requests")
+        cached_batch = self._cached_batch(self.batch_id, prefilled_active_slots)
        self.batch_id += 1
        logger.debug("Model ready for decoding")
-        return generations, batch
+        return generations, cached_batch

    def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndarray:
        pad_token_id = self.tokenizer.pad_token_id
-        batch_size = logits.shape[0]
-        tokens = jnp.full((batch_size, 1), pad_token_id)
-        for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots):
-            # Every slot might have a different selection criteria, so we are obliged to call select in a loop
-            next_token = slot.select(logits)
-            tokens = tokens.at[slot.id].set(next_token)
+
+        def inner_select_from_slots(logits, batch_size):
+            tokens = jnp.full((batch_size, 1), pad_token_id)
+            for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots):
+                # Every slot might have a different selection criteria, so we are obliged to call select in a loop
+                next_token = slot.select(logits[slot.id : slot.id + 1, :])
+                tokens = tokens.at[slot.id].set(next_token)
+            return tokens
+
+        # NOTE: The above code has been written in a non-functional way, thus causing problems when the function is
+        # JIT'ed. The workaround is to use jax.pure_callback, but this might cause performance issues: if those are
+        # observed, the code should be modified accordingly.
+        tokens = jax.pure_callback(
+            inner_select_from_slots,
+            result_shape_dtypes=jax.ShapeDtypeStruct((batch_size, 1), jnp.int32),
+            logits=logits,
+            batch_size=batch_size,
+        )
        return tokens

    def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:

@@ -528,16 +530,14 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        # just carry on with decoding. We adopt the id of the first
        # batch in the list as our next batch id.
        next_batch_id = batches[0].id
-        if len(batches) > 1:
-            logger.warning("Unexpected multiple batches received, only the first one will be processed.")
        request_ids = []
        for batch in batches:
            request_ids += batch.request_ids
        cleared_request_ids = []
        for slot in self.slots:
            if slot.state == slot.State.READY and slot.request_id not in request_ids:
                cleared_request_ids.append(slot.request_id)
-                self.slots.remove(slot)
+                slot.clear()
        if len(cleared_request_ids) > 0:
            logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.")
        active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]

@@ -551,7 +551,6 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        select_fn = jax.tree_util.Partial(self._select_from_slots)
        self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn)

-        newly_empty = []
        generations = []
        for slot in active_slots:
            # Get the next token.

@@ -564,20 +563,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
                raise ValueError("Unexpected Slot is not ready for decoding")

            self._post_generate(slot, next_token, generations)
-            if slot.empty:
-                newly_empty.append(slot)
-
-        # Remove empty slots
-        for slot in newly_empty:
-            self.slots.remove(slot)
-        batch = None
-        if len(self.slots) > 0:
-            # Whatever initial batch these requests came from, we always return all pending requests in a single batch
-            request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
-            batch = self._cached_batch(next_batch_id, request_ids)
-        else:
-            logger.debug("No more pending requests")
-        return generations, batch
+
+        cached_batch = self._cached_batch(next_batch_id, active_slots)
+        return generations, cached_batch

    def _post_generate(self, slot: Slot, next_token: int, generations: List[Generation]) -> None:
        """Post-generate a slot after the generation has been completed.
Expand Down Expand Up @@ -625,7 +613,13 @@ def _post_generate(self, slot: Slot, next_token: int, generations: List[Generati
)
)

def _cached_batch(self, batch_id: int, request_ids: List):
def _cached_batch(self, batch_id: int, active_slots: List):
"""Create a CachedBatch from the active slots.
"""
request_ids = [slot.request_id for slot in active_slots if slot.state == Slot.State.READY]
if len(request_ids) == 0:
logger.debug("No more pending requests")
return None
size = len(request_ids)
max_tokens = size * self.model.config.sequence_length
return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens)
Expand Down

@@ -58,6 +58,8 @@ def __init__(
        # Seed needs to fit a 64-bit integer, so we modulo it in case is bigger (that can happen!)
        seed = seed % jnp.iinfo(jnp.int64).max
        self.key = jax.random.PRNGKey(seed)
+        # TODO: it seems the sample method needs to be JIT'ed, otherwise for some reason it seems to get stuck.
+        self.sample = jax.jit(self._sample)

    @classmethod
    def create(

@@ -175,7 +177,7 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
        """
        scores = self.logits_processor(input_ids, logits)
        if self.mode == GenerationMode.SAMPLE:
-            return self._sample(scores)
+            return self.sample(scores)
        else:
            return jnp.argmax(scores, axis=-1)

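The two hunks above implement the "force JIT on this last part" workaround from the commit message. A rough sketch of the pattern, with hypothetical names (`TinySampler` is not the repository's class), might look like this:

```python
import jax
import jax.numpy as jnp


class TinySampler:
    """Illustrative sampler wrapper; assumes a fixed PRNG key for simplicity."""

    def __init__(self, seed: int = 0):
        self.key = jax.random.PRNGKey(seed)
        # Compile the sampling step once up front; the commit notes that the
        # un-jitted sampler got stuck when invoked from the host callback, and
        # pre-compiling it was the workaround.
        self.sample = jax.jit(self._sample)

    def _sample(self, scores: jnp.ndarray) -> jnp.ndarray:
        # Note: `self.key` is captured as a constant at trace time here;
        # real code would split or update the key between calls.
        return jax.random.categorical(self.key, scores, axis=-1)


sampler = TinySampler(seed=42)
print(sampler.sample(jnp.array([[0.1, 2.0, 0.3]])))
```
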
141 changes: 141 additions & 0 deletions text-generation-inference/tests/test_tinyllama.py
@@ -0,0 +1,141 @@

import pytest
from helpers import create_request, prepare_model
from text_generation_server.auto_generator import AutoGenerator
from text_generation_server.pb.generate_pb2 import Batch
from tqdm import tqdm


MODEL_ID = "Maykeye/TinyLLama-v0"
SEQUENCE_LENGTH = 256


@pytest.fixture(scope="module")
def model_path():
    return prepare_model(MODEL_ID, SEQUENCE_LENGTH)


def test_info(model_path):
    generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1)
    info = generator.info
    assert info.requires_padding is True
    assert info.device_type == "meta"
    assert info.window_size == 0
    assert info.speculate == 0


@pytest.mark.parametrize(
    "input_text, token_id, token_text, do_sample",
    [
        [
            "It was a bright cold day in April, and the clocks were striking thirteen.",
            347,
            " The",
            False,
        ],
        [
            "It was a bright cold day in April, and the clocks were striking thirteen.",
            13,
            "\n",
            True,
        ],
    ],
    ids=["greedy", "sample"],
)
@pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"])
def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
    generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH)
    requests = []
    max_new_tokens = 20
    for i in range(batch_size):
        requests.append(create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens))
    # Let's be pessimistic when estimating max_tokens
    batch_size * (len(input_text) + max_new_tokens)
    batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH)
    generations, next_batch = generator.prefill(batch)
    assert next_batch.size == batch_size
    # Whatever was passed as max_tokens, the server will correct it
    # because of static batching
    assert next_batch.max_tokens == batch_size * SEQUENCE_LENGTH
    assert len(generations) == batch_size
    for g in generations:
        tokens = g.tokens
        assert tokens.ids == [token_id]
        assert tokens.texts == [token_text]
    # Redo but with greedy
    batch.requests[0].parameters.do_sample = False
    generator.clear()
    generations, next_batch = generator.prefill(batch)
    print(generations[0])
    batch.requests[0].parameters.do_sample = True
    generator.clear()
    generations, next_batch = generator.prefill(batch)
    print(generations[0])


def test_decode_multiple(model_path):
    generator = AutoGenerator.from_pretrained(model_path,
                                              revision="",
                                              max_batch_size=2,
                                              max_sequence_length=SEQUENCE_LENGTH)
    input_text = "Once upon a time"
    max_new_tokens = 20
    # Prefill a single request, remembering the generated token
    tokens = {0: [], 1: []}
    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
    batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
    generations, next_batch = generator.prefill(batch)
    assert next_batch.size == 1
    assert len(generations) == 1
    g = generations[0]
    tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == 1
    # Decode a few tokens
    gen_tokens = 4
    for _ in tqdm(range(gen_tokens - 1), "Decoding tokens"):
        generations, next_batch = generator.decode([next_batch])
        assert len(generations) == 1
        g = generations[0]
        tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == gen_tokens
    assert next_batch.size == 1
    # Add a second request
    request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
    batch = Batch(id=1, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
    generations, next_batch_1 = generator.prefill(batch)
    assert next_batch_1.size == 1
    # We should have generated only a single token
    assert len(generations) == 1
    g = generations[0]
    tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == gen_tokens
    assert len(tokens[1]) == 1
    # Decode more tokens until we reach the maximum for the first request
    batches = [next_batch, next_batch_1]
    for _ in tqdm(range(max_new_tokens - gen_tokens), "Decoding tokens (2nd batch)"):
        generations, next_batch = generator.decode(batches)
        for g in generations:
            tokens[g.request_id].append(g.tokens.ids[0])
        batches = [next_batch]
    # Verify we now only have one pending request
    assert next_batch.size == 1
    assert len(tokens[0]) == max_new_tokens
    assert len(tokens[1]) == max_new_tokens - gen_tokens + 1
    # Verify we have the output for the first request
    for g in generations:
        if g.request_id == 0:
            output = g.generated_text
            assert output.text != ""
            assert output.generated_tokens == max_new_tokens
            generated_text = output.text
    # Continue decoding until the end of the second request
    for _ in tqdm(range(gen_tokens - 1), "Decoding tokens (finishing)"):
        generations, next_batch = generator.decode([next_batch])
        assert len(generations) == 1
        g = generations[0]
        tokens[g.request_id].append(g.tokens.ids[0])
    assert next_batch is None
    output = generations[0].generated_text
    assert output.generated_tokens == max_new_tokens
    assert tokens[0] == tokens[1]
    assert output.text == generated_text
