diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py
index 3c622517f..daa77510b 100644
--- a/benchmarks/bench_json_schema.py
+++ b/benchmarks/bench_json_schema.py
@@ -5,7 +5,11 @@
 from outlines.fsm.guide import RegexGuide  # noqa: E402
 from outlines.fsm.json_schema import build_regex_from_schema  # noqa: E402
 
-from .common import ensure_numba_compiled, setup_tokenizer  # noqa: E402
+from .common import (  # noqa: E402
+    clear_outlines_cache,
+    ensure_numba_compiled,
+    setup_tokenizer,
+)
 
 simple_schema = """{
   "$defs": {
@@ -70,6 +74,7 @@ class JsonSchemaBenchmark:
     params = schemas.keys()
 
     def setup(self, schema_name):
+        clear_outlines_cache()
         self.tokenizer = setup_tokenizer()
         self.schema = schemas[schema_name]
         ensure_numba_compiled(self.tokenizer)
diff --git a/benchmarks/bench_numba_compile.py b/benchmarks/bench_numba_compile.py
index 1cf50b2a9..c0e9d87c4 100644
--- a/benchmarks/bench_numba_compile.py
+++ b/benchmarks/bench_numba_compile.py
@@ -5,13 +5,14 @@
 
 import outlines
 
-from .common import setup_tokenizer
+from .common import clear_outlines_cache, setup_tokenizer
 
 outlines.disable_cache()
 
 
 class NumbaCompileBenchmark:
     def setup(self):
+        clear_outlines_cache()
         from outlines.fsm import regex
 
         self.tokenizer = setup_tokenizer()
diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py
index 4fcf1e7a7..efaea9e1f 100644
--- a/benchmarks/bench_regex_guide.py
+++ b/benchmarks/bench_regex_guide.py
@@ -1,6 +1,6 @@
 import outlines
 
-from .common import ensure_numba_compiled, setup_tokenizer
+from .common import clear_outlines_cache, ensure_numba_compiled, setup_tokenizer
 
 outlines.disable_cache()
 
@@ -23,6 +23,7 @@ class RegexGuideBenchmark:
     params = regex_samples.keys()
 
     def setup(self, pattern_name):
+        clear_outlines_cache()
         self.tokenizer = setup_tokenizer()
         ensure_numba_compiled(self.tokenizer)
         self.pattern = regex_samples[pattern_name]
@@ -35,6 +36,7 @@ class MemoryRegexGuideBenchmark:
     params = ["simple_phone", "complex_span_constrained_relation_extraction"]
 
     def setup(self, pattern_name):
+        clear_outlines_cache()
         self.tokenizer = setup_tokenizer()
         ensure_numba_compiled(self.tokenizer)
         self.pattern = regex_samples[pattern_name]
diff --git a/benchmarks/common.py b/benchmarks/common.py
index 7d999ea9b..e0fe36f14 100644
--- a/benchmarks/common.py
+++ b/benchmarks/common.py
@@ -1,9 +1,14 @@
 from transformers import AutoTokenizer
 
+import outlines.caching
 from outlines.fsm.guide import RegexGuide
 from outlines.models.transformers import TransformerTokenizer
 
 
+def clear_outlines_cache():
+    outlines.caching.clear_cache()
+
+
 def setup_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
     return TransformerTokenizer(tokenizer)