From 542a9bce4256075cc04066b11884edc130b060ae Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Mon, 4 Nov 2024 13:22:58 +0000
Subject: [PATCH] fix(generation): correct generation for batch size > 1

The sampling function passed to Jetstream's generate method is JIT'ed.
Because of this, it did not detect that the slots had been modified,
which produced incorrect generations for batch size > 1.

The solution is to use `jax.pure_callback`, which goes back to Python
and reads the current values of the slot objects. At that point another
issue appeared: the sampling functions got stuck forever when called in
this context. The workaround is to force JIT on that last part.
---
 .../jetstream_pt_support/generator.py      | 108 +++++++-------
 .../jetstream_pt_support/token_selector.py |   4 +-
 .../tests/test_tinyllama.py                | 141 ++++++++++++++++++
 3 files changed, 195 insertions(+), 58 deletions(-)
 create mode 100644 text-generation-inference/tests/test_tinyllama.py

diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index 6d19da6..ae0aa1c 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -241,7 +241,13 @@ def set(self, slot: Slot):
         self._curslot = slot
 
     def select(self, logits: jnp.ndarray) -> int:
-        return self._curslot.select(logits)
+        def _inner_select(logits):
+            return self._curslot.select(logits)
+        token = jax.pure_callback(
+            _inner_select,
+            result_shape_dtypes=jax.ShapeDtypeStruct((), jnp.int32),
+            logits=logits)
+        return token
 
 class TpuGeneratorJetStream(Generator):
     """A Generator for models running on TPU, single threaded."""
@@ -266,9 +272,7 @@ def __init__(
         tokenizer.truncation_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
-
-        # Slots are empty to begin with, they will be populated as new batches arrive
-        self.slots = []
+        self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
         self.batch_id = 0
         # Note: this index will _never_ be decremented, and that's fine.
         self.slot_index = 0
@@ -366,13 +370,11 @@ def warmup(self, batch: Batch) -> int:
         seq_len = self.engine.env.seq_len
         return batch_size * seq_len
 
-    def _get_slot_id(self):
-        """Get the next available slot id."""
-        batch_size = self.engine.env.batch_size
-        used_ids = [slot.id for slot in self.slots if slot.state != Slot.State.EMPTY]
-        for i in range(batch_size):
-            if i not in used_ids:
-                return i
+    def _get_slot(self):
+        """Get the next available slot."""
+        for slot in self.slots:
+            if slot.state == Slot.State.EMPTY:
+                return slot
         # if we reach this point, all slots were used - this should not happen
         raise ValueError("All slots are used, but we should have stopped earlier")
 
@@ -422,14 +424,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
 
             A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
""" - slots = {state: [] for state in Slot.State} - for slot in self.slots: - slots[slot.state].append(slot) - len_active_slots = len(slots[Slot.State.READY]) - # Delete all empty slots, no need to have them anymore - empty_slots = slots[Slot.State.EMPTY] - for slot in empty_slots: - self.slots.remove(slot) + active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY] + len_active_slots = len(active_slots) + len_requests = len(batch.requests) model_batch_size = self.model.config.batch_size if model_batch_size is not None and model_batch_size < len_active_slots + len_requests: @@ -444,10 +441,10 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: # Assign each request to an empty slot logger.debug(f"Prefilling {len_requests} new request(s) adding to {len_active_slots} active slot(s)") generations = [] - + prefilled_active_slots = [] for request in batch.requests: # Dynamically create a new slot for each request - slot = Slot(self._get_slot_id(), self.tokenizer) + slot = self._get_slot() self.prefill_slot.set(slot) self.slot_index += 1 slot.assign(self.batch_id, request, self.model.generation_config) @@ -479,29 +476,34 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]: self._post_generate(slot, next_token, generations) if not slot.empty: - # append current to list of active slots - self.slots.append(slot) - len_active_slots += 1 - - batch = None - if len_active_slots > 0: - # Whatever initial batch these requests came from, we always return all pending requests in a single batch - request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY] - batch = self._cached_batch(self.batch_id, request_ids) - else: - logger.debug("No more pending requests") + prefilled_active_slots.append(slot) + + cached_batch = self._cached_batch(self.batch_id, prefilled_active_slots) self.batch_id += 1 logger.debug("Model ready for decoding") - return generations, batch + return generations, cached_batch def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndarray: pad_token_id = self.tokenizer.pad_token_id batch_size = logits.shape[0] - tokens = jnp.full((batch_size, 1), pad_token_id) - for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots): - # Every slot might have a different selection criteria, so we are obliged to call select in a loop - next_token = slot.select(logits) - tokens = tokens.at[slot.id].set(next_token) + + def inner_select_from_slots(logits, batch_size): + tokens = jnp.full((batch_size, 1), pad_token_id) + for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots): + # Every slot might have a different selection criteria, so we are obliged to call select in a loop + next_token = slot.select(logits[slot.id : slot.id + 1, :]) + tokens = tokens.at[slot.id].set(next_token) + return tokens + + # NOTE: The above code has been written in a non-functional way, thus causing problems when the function is + # JIT'ed. The workaround is to use jax.pure_callback, but this might cause performance issues: if those are + # observed, the code should be modified accordingly. 
+        tokens = jax.pure_callback(
+            inner_select_from_slots,
+            result_shape_dtypes=jax.ShapeDtypeStruct((batch_size, 1), jnp.int32),
+            logits=logits,
+            batch_size=batch_size,
+        )
         return tokens
 
     def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
@@ -528,8 +530,6 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
         # just carry on with decoding. We adopt the id of the first
         # batch in the list as our next batch id.
         next_batch_id = batches[0].id
-        if len(batches) > 1:
-            logger.warning("Unexpected multiple batches received, only the first one will be processed.")
         request_ids = []
         for batch in batches:
             request_ids += batch.request_ids
@@ -537,7 +537,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
         for slot in self.slots:
             if slot.state == slot.State.READY and slot.request_id not in request_ids:
                 cleared_request_ids.append(slot.request_id)
-                self.slots.remove(slot)
+                slot.clear()
         if len(cleared_request_ids) > 0:
             logger.info(f"Clearing slot for requests {cleared_request_ids} as they are not requested.")
         active_slots = [slot for slot in self.slots if slot.state == slot.State.READY]
@@ -551,7 +551,6 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
         select_fn = jax.tree_util.Partial(self._select_from_slots)
         self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn)
 
-        newly_empty = []
         generations = []
         for slot in active_slots:
             # Get the next token.
@@ -564,20 +563,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
                 raise ValueError("Unexpected Slot is not ready for decoding")
             self._post_generate(slot, next_token, generations)
-            if slot.empty:
-                newly_empty.append(slot)
-
-        # Remove empty slots
-        for slot in newly_empty:
-            self.slots.remove(slot)
-        batch = None
-        if len(self.slots) > 0:
-            # Whatever initial batch these requests came from, we always return all pending requests in a single batch
-            request_ids = [slot.request_id for slot in self.slots if slot.state == Slot.State.READY]
-            batch = self._cached_batch(next_batch_id, request_ids)
-        else:
-            logger.debug("No more pending requests")
-        return generations, batch
+
+        cached_batch = self._cached_batch(next_batch_id, active_slots)
+        return generations, cached_batch
 
     def _post_generate(self, slot: Slot, next_token: int, generations: List[Generation]) -> None:
         """Post-generate a slot after the generation has been completed.
@@ -625,7 +613,13 @@ def _post_generate(self, slot: Slot, next_token: int, generations: List[Generati
             )
         )
 
-    def _cached_batch(self, batch_id: int, request_ids: List):
+    def _cached_batch(self, batch_id: int, active_slots: List):
+        """Create a CachedBatch from the active slots.
+ """ + request_ids = [slot.request_id for slot in active_slots if slot.state == Slot.State.READY] + if len(request_ids) == 0: + logger.debug("No more pending requests") + return None size = len(request_ids) max_tokens = size * self.model.config.sequence_length return CachedBatch(id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py index 6f300cb..ee7e8ce 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py @@ -58,6 +58,8 @@ def __init__( # Seed needs to fit a 64-bit integer, so we modulo it in case is bigger (that can happen!) seed = seed % jnp.iinfo(jnp.int64).max self.key = jax.random.PRNGKey(seed) + # TODO: it seems the sample method needs to be JIT'ed, otherwise for some reason it seems to get stuck. + self.sample = jax.jit(self._sample) @classmethod def create( @@ -175,7 +177,7 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray: """ scores = self.logits_processor(input_ids, logits) if self.mode == GenerationMode.SAMPLE: - return self._sample(scores) + return self.sample(scores) else: return jnp.argmax(scores, axis=-1) diff --git a/text-generation-inference/tests/test_tinyllama.py b/text-generation-inference/tests/test_tinyllama.py new file mode 100644 index 0000000..40fc2b1 --- /dev/null +++ b/text-generation-inference/tests/test_tinyllama.py @@ -0,0 +1,141 @@ + +import pytest +from helpers import create_request, prepare_model +from text_generation_server.auto_generator import AutoGenerator +from text_generation_server.pb.generate_pb2 import Batch +from tqdm import tqdm + + +MODEL_ID = "Maykeye/TinyLLama-v0" +SEQUENCE_LENGTH = 256 + + +@pytest.fixture(scope="module") +def model_path(): + return prepare_model(MODEL_ID, SEQUENCE_LENGTH) + + +def test_info(model_path): + generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=1) + info = generator.info + assert info.requires_padding is True + assert info.device_type == "meta" + assert info.window_size == 0 + assert info.speculate == 0 + + +@pytest.mark.parametrize( + "input_text, token_id, token_text, do_sample", + [ + [ + "It was a bright cold day in April, and the clocks were striking thirteen.", + 347, + " The", + False, + ], + [ + "It was a bright cold day in April, and the clocks were striking thirteen.", + 13, + "\n", + True, + ], + ], + ids=["greedy", "sample"], +) +@pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"]) +def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path): + generator = AutoGenerator.from_pretrained(model_path, revision="", max_batch_size=batch_size, max_sequence_length=SEQUENCE_LENGTH) + requests = [] + max_new_tokens = 20 + for i in range(batch_size): + requests.append(create_request(id=0, inputs=input_text, do_sample=do_sample, max_new_tokens=max_new_tokens)) + # Let's be pessimistic when estimating max_tokens + batch_size * (len(input_text) + max_new_tokens) + batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=batch_size * SEQUENCE_LENGTH) + generations, next_batch = generator.prefill(batch) + assert next_batch.size == batch_size + # Whatever was passed as max_tokens, the server will correct 
+    # because of static batching
+    assert next_batch.max_tokens == batch_size * SEQUENCE_LENGTH
+    assert len(generations) == batch_size
+    for g in generations:
+        tokens = g.tokens
+        assert tokens.ids == [token_id]
+        assert tokens.texts == [token_text]
+    # Redo but with greedy
+    batch.requests[0].parameters.do_sample = False
+    generator.clear()
+    generations, next_batch = generator.prefill(batch)
+    print(generations[0])
+    batch.requests[0].parameters.do_sample = True
+    generator.clear()
+    generations, next_batch = generator.prefill(batch)
+    print(generations[0])
+
+
+def test_decode_multiple(model_path):
+    generator = AutoGenerator.from_pretrained(model_path,
+                                               revision="",
+                                               max_batch_size=2,
+                                               max_sequence_length=SEQUENCE_LENGTH)
+    input_text = "Once upon a time"
+    max_new_tokens = 20
+    # Prefill a single request, remembering the generated token
+    tokens = {0: [], 1: []}
+    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
+    batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
+    generations, next_batch = generator.prefill(batch)
+    assert next_batch.size == 1
+    assert len(generations) == 1
+    g = generations[0]
+    tokens[g.request_id].append(g.tokens.ids[0])
+    assert len(tokens[0]) == 1
+    # Decode a few tokens
+    gen_tokens = 4
+    for _ in tqdm(range(gen_tokens - 1), "Decoding tokens"):
+        generations, next_batch = generator.decode([next_batch])
+        assert len(generations) == 1
+        g = generations[0]
+        tokens[g.request_id].append(g.tokens.ids[0])
+    assert len(tokens[0]) == gen_tokens
+    assert next_batch.size == 1
+    # Add a second request
+    request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
+    batch = Batch(id=1, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
+    generations, next_batch_1 = generator.prefill(batch)
+    assert next_batch_1.size == 1
+    # We should have generated only a single token
+    assert len(generations) == 1
+    g = generations[0]
+    tokens[g.request_id].append(g.tokens.ids[0])
+    assert len(tokens[0]) == gen_tokens
+    assert len(tokens[1]) == 1
+    # Decode more tokens until we reach the maximum for the first request
+    batches = [next_batch, next_batch_1]
+    for _ in tqdm(range(max_new_tokens - gen_tokens), "Decoding tokens (2nd batch)"):
+        generations, next_batch = generator.decode(batches)
+        for g in generations:
+            tokens[g.request_id].append(g.tokens.ids[0])
+        batches = [next_batch]
+    # Verify we now only have one pending request
+    assert next_batch.size == 1
+    assert len(tokens[0]) == max_new_tokens
+    assert len(tokens[1]) == max_new_tokens - gen_tokens + 1
+    # Verify we have the output for the first request
+    for g in generations:
+        if g.request_id == 0:
+            output = g.generated_text
+            assert output.text != ""
+            assert output.generated_tokens == max_new_tokens
+            generated_text = output.text
+    # Continue decoding until the end of the second request
+    for _ in tqdm(range(gen_tokens - 1), "Decoding tokens (finishing)"):
+        generations, next_batch = generator.decode([next_batch])
+        assert len(generations) == 1
+        g = generations[0]
+        tokens[g.request_id].append(g.tokens.ids[0])
+    assert next_batch is None
+    output = generations[0].generated_text
+    assert output.generated_tokens == max_new_tokens
+    assert tokens[0] == tokens[1]
+    assert output.text == generated_text
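
A minimal, self-contained sketch of the jax.pure_callback pattern used in the patch is shown below for reference. It is hypothetical and not part of the patch: Selector, generate_step and the offset field are made-up stand-ins for the generator's slots, and only the mechanism is illustrated, namely why a mutable Python object has to be reached through a host callback once the surrounding function is JIT'ed. It assumes a JAX version recent enough to provide jax.pure_callback, which the patch itself already relies on.

import jax
import jax.numpy as jnp


class Selector:
    """Stateful stand-in (hypothetical) for the generator's slots."""

    def __init__(self):
        self.offset = 0  # mutated from regular Python between calls

    def select(self, logits):
        # Runs as plain Python, so it always sees the current value of offset.
        return jnp.argmax(logits, axis=-1).astype(jnp.int32) + self.offset


selector = Selector()


@jax.jit
def generate_step(logits):
    # Without the callback, selector.offset would be baked in at trace time,
    # which is the kind of staleness the commit message describes for batch size > 1.
    return jax.pure_callback(
        lambda x: selector.select(x),
        jax.ShapeDtypeStruct(logits.shape[:1], jnp.int32),
        logits,
    )


logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
print(generate_step(logits))  # [1 0]
selector.offset = 10          # state changed outside of JAX
print(generate_step(logits))  # [11 10] -- the callback picks up the new state

Because the callback leaves the compiled computation and runs on the host at every step, it can become a bottleneck, which is presumably what the NOTE added in _select_from_slots warns about.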