Skip to content

Commit

Permalink
Code refactor&Support custom extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan-wanna-M committed Aug 20, 2024
1 parent 63f6d82 commit 5faff8a
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 10 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ print(logits_processor[0].formatters_captures)
# possible output:
# [{'digit': [<re.Match object; span=(0, 2), match='42'>, <re.Match object; span=(0, 2), match='42'>]}]
```
Note that only the
[Rust regex syntax](https://docs.rs/regex/latest/regex/#syntax) is supported; notably, it
does not include arbitrary lookaheads.
### Json Generation
#### Pydantic Model
```python
Expand Down Expand Up @@ -200,6 +203,7 @@ print(logits_processor[0].formatters_captures)
# possible output:
# [{'json': 14}]
```

### Json Schema
You can use [pydantic's code generator](https://docs.pydantic.dev/latest/integrations/datamodel_code_generator/)
to generate pydantic models from json schema.
Expand Down
17 changes: 16 additions & 1 deletion src/formatron/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,15 @@ def kbnf_representation(self) -> str:
"""
pass

@property
def nonterminal(self) -> str:
"""
Get the nonterminal representing the extractor.
"""
raise NotImplementedError("This extractor does not have a corresponding nonterminal.")

def __str__(self):
return f"${{{self.kbnf_representation}}}"
return f"${{{self.nonterminal}}}"


class LiteralExtractor(Extractor):
Expand Down Expand Up @@ -107,6 +114,10 @@ def extract(self, input_str: str) -> typing.Optional[tuple[str, re.Match | None]
def kbnf_representation(self) -> str:
return self._nonterminal

@property
def nonterminal(self) -> str:
return self._nonterminal


class ChoiceExtractor(Extractor):
"""
Expand Down Expand Up @@ -135,3 +146,7 @@ def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
@property
def kbnf_representation(self) -> str:
return self._nonterminal

@property
def nonterminal(self) -> str:
return self._nonterminal
26 changes: 23 additions & 3 deletions src/formatron/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def mask_logits(self, logits) -> typing.Any:
"""
pass

@abc.abstractmethod
def get_allowed_tokens_since_last_computation(self) -> typing.Sequence[int]:
"""
Get the allowed tokens since the last computation (in other words, since the last call to `compute_allowed_tokens`).
:return: The allowed tokens.
"""
pass

@abc.abstractmethod
def is_completed(self) -> bool:
"""
Expand Down Expand Up @@ -265,12 +273,24 @@ def choose(self, *extractors: Extractor | str, capture_name: str = None) -> Choi

def _add_extractor(self, capture_name: str, extractor_type: str,
create_extractor: typing.Callable[[str], Extractor],
create_rule: typing.Callable[[str], str]):
create_rules: typing.Callable[[str], str]):
nonterminal = self._create_nonterminal(capture_name, extractor_type)
self._nonterminal_to_extractor[nonterminal] = create_extractor(nonterminal)
self._rules.append(create_rule(nonterminal))
self._rules.append(create_rules(nonterminal))
return self._nonterminal_to_extractor[nonterminal]

def extractor(self, create_extractor: typing.Callable[[str], Extractor],
create_rules: typing.Callable[[str], str], capture_name: str = None) -> Extractor:
"""
Create a custom extractor.
:param create_extractor: callable with signature (extractor_nonterminal: str)->Extractor that creates the extractor. extractor_nonterminal is the auto-generated nonterminal reference for the extractor.
:param create_rules: callable with signature (extractor_nonterminal: str)->str that creates the KBNF rules for the extractor. It is separated from create_extractor to allow more flexibility. For example, you can reuse the same extractor with different rules.
:param capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
"""
return self._add_extractor(capture_name, "extractor",
create_extractor,
create_rules)

def regex(self, regex: str, *, capture_name: str = None) -> RegexExtractor:
"""
Create a regex extractor.
Expand Down Expand Up @@ -313,7 +333,7 @@ def str(self, *, stop: typing.Union[str, list[str]] = None,
nonterminal_regex = "#'.*'"
else:
backslash = '\\'
capture_regex = f".*?(?:{'|'.join([i.replace(backslash, backslash*2) for i in map(re.escape, stop)])})"
capture_regex = f".*?(?:{'|'.join([i.replace(backslash, backslash * 2) for i in map(re.escape, stop)])})"
nonterminal_regex = f"#e'{capture_regex}'"
self._rules.append(f"{nonterminal} ::= {nonterminal_regex};")
self._nonterminal_to_extractor[nonterminal] = RegexExtractor(capture_regex, capture_name, nonterminal)
Expand Down
4 changes: 4 additions & 0 deletions src/formatron/grammar_generators/json_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ def __init__(self, nonterminal: str, capture_name: typing.Optional[str],
def kbnf_representation(self) -> str:
return self._nonterminal

@property
def nonterminal(self) -> str:
return self._nonterminal

def extract(self, input_str: str) -> typing.Optional[tuple[str, schemas.schema.Schema]]:
"""
Extract a schema instance from a string.
Expand Down
6 changes: 4 additions & 2 deletions src/formatron/integrations/exllamav2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from exllamav2 import ExLlamaV2Tokenizer, ExLlamaV2
from exllamav2.generator.base import ExLlamaV2Filter
from config import EngineGenerationConfig
from formatter import Formatter, FormatterBuilder
from formatter import FormatterBase, FormatterBuilder
from integrations._utils import get_original_characters

from formatron.formatter import FormatterBase


def create_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer) -> kbnf.Vocabulary:
"""
Expand Down Expand Up @@ -41,7 +43,7 @@ class FormatterFilter(ExLlamaV2Filter):
ExLlamaV2Filter that uses a formatter to mask logits.
"""

def __init__(self, model, tokenizer, formatter: Formatter,
def __init__(self, model, tokenizer, formatter: FormatterBase,
config: EngineGenerationConfig = None):
super().__init__(model, tokenizer)
self._formatter = formatter
Expand Down
4 changes: 2 additions & 2 deletions src/formatron/integrations/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from transformers import LogitsProcessor, PreTrainedTokenizerBase, LogitsProcessorList

from config import EngineGenerationConfig
from formatter import Formatter, FormatterBuilder
from formatter import FormatterBuilder, FormatterBase
from integrations._utils import get_original_characters


Expand Down Expand Up @@ -52,7 +52,7 @@ class FormattersLogitsProcessor(LogitsProcessor):
Logit processor that uses formatters to mask batch logits.
"""

def __init__(self, formatters: typing.Sequence[Formatter], eos_token_id: int,
def __init__(self, formatters: typing.Sequence[FormatterBase], eos_token_id: int,
configs: typing.Sequence[EngineGenerationConfig] = None):
self._formatters = formatters
self._eos_token_id = eos_token_id
Expand Down
4 changes: 2 additions & 2 deletions src/formatron/integrations/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vllm import LLM

from config import EngineGenerationConfig
from formatter import Formatter, FormatterBuilder
from formatter import FormatterBase, FormatterBuilder
from integrations._utils import get_original_characters


Expand All @@ -17,7 +17,7 @@ class FormattersLogitsProcessor:
Logit processor that uses formatters to mask batch logits.
"""

def __init__(self, formatters: typing.Sequence[Formatter], eos_token_id: int,
def __init__(self, formatters: typing.Sequence[FormatterBase], eos_token_id: int,
configs: typing.Sequence[EngineGenerationConfig] = None):
self._formatters = formatters
self._eos_token_id = eos_token_id
Expand Down

0 comments on commit 5faff8a

Please sign in to comment.