diff --git a/src/mistral_inference/generate.py b/src/mistral_inference/generate.py index c9e35c5..4d2dbe8 100644 --- a/src/mistral_inference/generate.py +++ b/src/mistral_inference/generate.py @@ -112,7 +112,7 @@ def generate( next_token = sample(last_token_prelogits, temperature=temperature, top_p=0.8) if eos_id is not None: - is_finished = is_finished ^ (next_token == eos_id).cpu() + is_finished = is_finished | (next_token == eos_id).cpu() if is_finished.all(): break