From d676bb989f6e6a9a12355afe98636bb690b6fa6a Mon Sep 17 00:00:00 2001 From: pgrundmann Date: Thu, 30 May 2024 14:59:52 +0200 Subject: [PATCH 1/5] Add vllm mask cache --- outlines/integrations/vllm.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/outlines/integrations/vllm.py b/outlines/integrations/vllm.py index 6ed56d71b..fcb2161d0 100644 --- a/outlines/integrations/vllm.py +++ b/outlines/integrations/vllm.py @@ -78,6 +78,7 @@ def __init__(self, regex_string: str, llm: "LLM"): "`tokenizer` attribute or a `get_tokenizer` method." ) tokenizer = adapt_tokenizer(tokenizer=tokenizer) + self.mask_cache = {} self.fsm = RegexGuide(regex_string, tokenizer) self._fsm_state: DefaultDict[int, int] = defaultdict(int) @@ -111,8 +112,13 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: state=self._fsm_state[seq_id] ).tokens - mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) - mask[allowed_tokens] = 0 + cache_key = hash(tuple(allowed_tokens[:2048])) + if cache_key not in self.mask_cache: + mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) + mask[allowed_tokens] = 0 + self.mask_cache[cache_key] = mask + else: + mask = self.mask_cache[cache_key] biased_scores = scores + mask return biased_scores From 1385324647d80d1520a6adc3a1c7b8447194c154 Mon Sep 17 00:00:00 2001 From: Paul Grundmann Date: Fri, 14 Jun 2024 10:23:43 +0200 Subject: [PATCH 2/5] Fix broken commits --- outlines/integrations/vllm.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/outlines/integrations/vllm.py b/outlines/integrations/vllm.py index fcb2161d0..b781f7316 100644 --- a/outlines/integrations/vllm.py +++ b/outlines/integrations/vllm.py @@ -108,17 +108,18 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: state=self._fsm_state[last_seq_id], token_id=last_token ) - allowed_tokens = self.fsm.get_next_instruction( - state=self._fsm_state[seq_id] - ).tokens - - cache_key = hash(tuple(allowed_tokens[:2048])) - if cache_key not in self.mask_cache: - mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) + state_id = self._fsm_state[seq_id] + if state_id not in self.mask_cache: + allowed_tokens = self.fsm.get_next_instruction( + state=self._fsm_state[seq_id] + ).tokens + mask = torch.full((scores.shape[-1],), -math.inf) mask[allowed_tokens] = 0 - self.mask_cache[cache_key] = mask + mask = mask.pin_memory() + self.mask_cache[state_id] = mask else: - mask = self.mask_cache[cache_key] + mask = self.mask_cache[state_id] + mask = mask.to(device=scores.device, non_blocking=True) biased_scores = scores + mask return biased_scores From c866632389459d8b0296af9b27b47391f28f24d5 Mon Sep 17 00:00:00 2001 From: Paul Grundmann Date: Fri, 14 Jun 2024 10:59:39 +0200 Subject: [PATCH 3/5] Fix pre-commit checks, added dict type annotation to mask_cache --- benchmarks/bench_vllm_mask_cache.py | 81 +++++++++++++++++++++++++++++ outlines/integrations/vllm.py | 4 +- 2 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 benchmarks/bench_vllm_mask_cache.py diff --git a/benchmarks/bench_vllm_mask_cache.py b/benchmarks/bench_vllm_mask_cache.py new file mode 100644 index 000000000..8d1ceeb24 --- /dev/null +++ b/benchmarks/bench_vllm_mask_cache.py @@ -0,0 +1,81 @@ +from outlines.caching import cache_disabled +from outlines.fsm.guide import RegexGuide +from outlines.fsm.json_schema import build_regex_from_schema + +from .common import ensure_numba_compiled, setup_tokenizer # noqa: E402 + +simple_schema = """{ + "$defs": { + "Armor": { + "enum": ["leather", "chainmail", "plate"], + "title": "Armor", + "type": "string" + } + }, + "properties": { + "name": {"maxLength": 10, "title": "Name", "type": "string"}, + "age": {"title": "Age", "type": "integer"}, + "armor": {"$ref": "#/$defs/Armor"}, + "strength": {"title": "Strength", "type": "integer"}\ + }, + "required": ["name", "age", "armor", "strength"], + "title": "Character", + "type": "object" + }""" + + +complex_schema = """{ + "$schema": "http://json-schema.org/draft-04/schema#", + "title": "Schema for a recording", + "type": "object", + "definitions": { + "artist": { + "type": "object", + "properties": { + "id": {"type": "number"}, + "name": {"type": "string"}, + "functions": { + "type": "array", + "items": {"type": "string"} + } + }, + "required": ["id", "name", "functions"] + } + }, + "properties": { + "id": {"type": "number"}, + "work": { + "type": "object", + "properties": { + "id": {"type": "number"}, + "name": {"type": "string"}, + "composer": {"$ref": "#/definitions/artist"} + } + }, + "recording_artists": { + "type": "array", + "items": {"$ref": "#/definitions/artist"} + } + }, + "required": ["id", "work", "recording_artists"] +}""" + +schemas = dict(simple_schema=simple_schema, complex_schema=complex_schema) + + +class JsonSchemaBenchmark: + params = schemas.keys() + + def setup(self, schema_name): + self.tokenizer = setup_tokenizer() + self.schema = schemas[schema_name] + ensure_numba_compiled(self.tokenizer) + + @cache_disabled() + def time_json_schema_to_regex(self, schema_name): + build_regex_from_schema(self.schema) + + @cache_disabled() + def time_json_schema_to_fsm(self, schema_name): + regex = build_regex_from_schema(self.schema) + RegexGuide(regex, self.tokenizer) diff --git a/outlines/integrations/vllm.py b/outlines/integrations/vllm.py index b781f7316..e7bda5884 100644 --- a/outlines/integrations/vllm.py +++ b/outlines/integrations/vllm.py @@ -27,7 +27,7 @@ import math from collections import defaultdict -from typing import TYPE_CHECKING, DefaultDict, List, Optional, Type, Union +from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Type, Union import torch from pydantic import BaseModel @@ -78,7 +78,7 @@ def __init__(self, regex_string: str, llm: "LLM"): "`tokenizer` attribute or a `get_tokenizer` method." ) tokenizer = adapt_tokenizer(tokenizer=tokenizer) - self.mask_cache = {} + self.mask_cache: Dict[int, torch.Tensor] = {} self.fsm = RegexGuide(regex_string, tokenizer) self._fsm_state: DefaultDict[int, int] = defaultdict(int) From 49f485b49bcf81f5435690cd9da06110ed4594df Mon Sep 17 00:00:00 2001 From: Paul Grundmann Date: Fri, 14 Jun 2024 13:07:38 +0200 Subject: [PATCH 4/5] Remove benchmark --- benchmarks/bench_vllm_mask_cache.py | 81 ----------------------------- 1 file changed, 81 deletions(-) delete mode 100644 benchmarks/bench_vllm_mask_cache.py diff --git a/benchmarks/bench_vllm_mask_cache.py b/benchmarks/bench_vllm_mask_cache.py deleted file mode 100644 index 8d1ceeb24..000000000 --- a/benchmarks/bench_vllm_mask_cache.py +++ /dev/null @@ -1,81 +0,0 @@ -from outlines.caching import cache_disabled -from outlines.fsm.guide import RegexGuide -from outlines.fsm.json_schema import build_regex_from_schema - -from .common import ensure_numba_compiled, setup_tokenizer # noqa: E402 - -simple_schema = """{ - "$defs": { - "Armor": { - "enum": ["leather", "chainmail", "plate"], - "title": "Armor", - "type": "string" - } - }, - "properties": { - "name": {"maxLength": 10, "title": "Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "armor": {"$ref": "#/$defs/Armor"}, - "strength": {"title": "Strength", "type": "integer"}\ - }, - "required": ["name", "age", "armor", "strength"], - "title": "Character", - "type": "object" - }""" - - -complex_schema = """{ - "$schema": "http://json-schema.org/draft-04/schema#", - "title": "Schema for a recording", - "type": "object", - "definitions": { - "artist": { - "type": "object", - "properties": { - "id": {"type": "number"}, - "name": {"type": "string"}, - "functions": { - "type": "array", - "items": {"type": "string"} - } - }, - "required": ["id", "name", "functions"] - } - }, - "properties": { - "id": {"type": "number"}, - "work": { - "type": "object", - "properties": { - "id": {"type": "number"}, - "name": {"type": "string"}, - "composer": {"$ref": "#/definitions/artist"} - } - }, - "recording_artists": { - "type": "array", - "items": {"$ref": "#/definitions/artist"} - } - }, - "required": ["id", "work", "recording_artists"] -}""" - -schemas = dict(simple_schema=simple_schema, complex_schema=complex_schema) - - -class JsonSchemaBenchmark: - params = schemas.keys() - - def setup(self, schema_name): - self.tokenizer = setup_tokenizer() - self.schema = schemas[schema_name] - ensure_numba_compiled(self.tokenizer) - - @cache_disabled() - def time_json_schema_to_regex(self, schema_name): - build_regex_from_schema(self.schema) - - @cache_disabled() - def time_json_schema_to_fsm(self, schema_name): - regex = build_regex_from_schema(self.schema) - RegexGuide(regex, self.tokenizer) From 1a7079b97ab532a60cf5b681d3da2f99362f6335 Mon Sep 17 00:00:00 2001 From: Paul Grundmann Date: Fri, 14 Jun 2024 13:08:14 +0200 Subject: [PATCH 5/5] Fix move to scores device --- outlines/integrations/vllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/outlines/integrations/vllm.py b/outlines/integrations/vllm.py index e7bda5884..2a5f26e35 100644 --- a/outlines/integrations/vllm.py +++ b/outlines/integrations/vllm.py @@ -119,7 +119,7 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: self.mask_cache[state_id] = mask else: mask = self.mask_cache[state_id] - mask = mask.to(device=scores.device, non_blocking=True) + mask = mask.to(device=scores.device, non_blocking=True) biased_scores = scores + mask return biased_scores