diff --git a/benchmarks/bench_cfg_guide.py b/benchmarks/bench_cfg_guide.py index 8f6de914a..14dc31c73 100644 --- a/benchmarks/bench_cfg_guide.py +++ b/benchmarks/bench_cfg_guide.py @@ -38,23 +38,31 @@ def setup(self, grammar_name): ) @staticmethod - def _run_random_cfg(guide): + def _run_random_cfg(guide, rejection_sampling=True): state = guide.initial_state token_ids = list(guide.tokenizer.vocabulary.values()) for i in range(40): # simulate ordering of logits top prob to lowest prob random.shuffle(token_ids) # simulate sampling and state update - next_token_id = next(guide.iter_valid_token_ids(state, token_ids)) - state = guide.get_next_state(state, next_token_id) + if rejection_sampling: + next_token_id = next(guide.iter_valid_token_ids(state, token_ids)) + state = guide.get_next_state(state, next_token_id) + else: + next_token_id = random.choice(guide.get_next_instruction(state).tokens) + state = guide.get_next_state(state, next_token_id) @cache_disabled() def time_cfg_guide_setup(self, grammar_name): CFGGuide(benched_grammars[grammar_name], self.tokenizer) + @cache_disabled() + def time_cfg_guide_run_rejection_sampling(self, grammar): + self._run_random_cfg(self.prebuilt_cfg_guide, rejection_sampling=True) + @cache_disabled() def time_cfg_guide_run(self, grammar): - self._run_random_cfg(self.prebuilt_cfg_guide) + self._run_random_cfg(self.prebuilt_cfg_guide, rejection_sampling=False) @cache_disabled() def peakmem_cfg_guide_run(self, grammar):