Skip to content

Commit

Permalink
fix(Jetstream): improve randomness in generator
Browse files Browse the repository at this point in the history
When do_sample is set, the token selector should use some randomness,
but the jax key was never updated, leading to results very close to no
randomness used.
This fixes it, making do_sample more unpredictable (as expected).
  • Loading branch information
tengomucho committed Nov 4, 2024
1 parent a45d0f6 commit 8d1c3df
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ def select(self, logits: jnp.ndarray) -> int:
logits = logits.reshape(1, -1)
return self._selector.select(self._tokens, logits)[0]

def update_rng_key(self):
self._selector.update_rng_key()

@property
def stopped(self) -> bool:
# unsqueeze tokens to avoid problems with stopping criteria
Expand Down Expand Up @@ -461,6 +464,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
seed=slot.seed,
)
slot.reset(truncated_input_ids, selector)
slot.update_rng_key()
# To allow jit'ing the select function, we need to wrap it in a partial
slot_select = jax.tree_util.Partial(self.prefill_slot.select)
# Ask for prefill and insert
Expand Down Expand Up @@ -540,6 +544,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
if len(active_slots) < len(request_ids):
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")

# Update RNG in all slots
for slot in active_slots:
slot.update_rng_key()
# Use a custom function to select the next token for each slot
select_fn = jax.tree_util.Partial(self._select_from_slots)
self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
else:
return jnp.argmax(scores, axis=-1)

def update_rng_key(self):
self.key, _ = jax.random.split(self.key)

def _sample(self, scores: jnp.ndarray) -> jnp.ndarray:
do_top_k = self.logits_warper.top_k > 0 and self.logits_warper.top_k < scores.shape[-1]
do_top_p = self.logits_warper.top_p < 1.0 and self.logits_warper.top_p > 0.0
Expand Down

0 comments on commit 8d1c3df

Please sign in to comment.