Skip to content

Commit

Permalink
Make formatter happy
Browse files: browse the repository at this point in the history
  • Loading branch information
Dan-wanna-M committed Aug 8, 2024
1 parent f22fbbf commit c6e3f02
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 19 deletions.
2 changes: 1 addition & 1 deletion dev-dependencies.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
doxypypy
git+https://github.com/Dan-wanna-M/doxypypy@master
2 changes: 1 addition & 1 deletion src/formatron/grammar_generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
This subpackage contains modules that define classes that generate KBNF grammars from schemas
and extract schema instances from strings adhering to the grammars.
"""
"""
2 changes: 0 additions & 2 deletions src/formatron/grammar_generators/grammar_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def generate(self, schema: typing.Type[schemas.schema.Schema], start_nonterminal
Returns:
The KBNF grammar string.
"""
pass

@abc.abstractmethod
def get_extractor(self, nonterminal: str, capture_name: typing.Optional[str],
Expand All @@ -41,4 +40,3 @@ def get_extractor(self, nonterminal: str, capture_name: typing.Optional[str],
Returns:
The extractor.
"""
pass
7 changes: 3 additions & 4 deletions src/formatron/grammar_generators/json_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class JsonGenerator(GrammarGenerator):
- Subclasses of collections.abc.Mapping[str,T] and typing.Mapping[str,T] where T is a supported type.
- Subclasses of collections.abc.Sequence[T] and typing.Sequence[T] where T is a supported type.
- tuple[T1,T2,...] where T1,T2,... are supported types. The order, type and number of elements will be preserved.
- typing.Literal[x1,x2,...] where x1, x2, ... are instances of int, string, bool or NoneType, or some other typing.Literal[y1,y2,...]
- typing.Literal[x1,x2,...] where x1, x2, ... are instances of int, string, bool or NoneType, or another typing.Literal[y1,y2,...]
- typing.Union[T1,T2,...] where T1,T2,... are supported types.
- schemas.Schema where all its fields' data types are supported. Recursive schema definitions are supported as well.
"""
Expand Down Expand Up @@ -91,9 +91,8 @@ def field_info(current: typing.Type, nonterminal: str):
if isinstance(current, schemas.schema.FieldInfo):
if current.required:
return "", [(current.annotation, nonterminal)]
else:
new_nonterminal = f"{nonterminal}_required"
return f"{nonterminal} ::= {new_nonterminal}?;\n", [(current.annotation, new_nonterminal)]
new_nonterminal = f"{nonterminal}_required"
return f"{nonterminal} ::= {new_nonterminal}?;\n", [(current.annotation, new_nonterminal)]
return None

def builtin_list(current: typing.Type, nonterminal: str):
Expand Down
2 changes: 1 addition & 1 deletion src/formatron/integrations/RWKV.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from formatter import FormatterBuilder


class PIPELINE_ARGS(rwkv.utils.PIPELINE_ARGS): # NOSONAR
class PIPELINE_ARGS(rwkv.utils.PIPELINE_ARGS):
"""
A wrapper for the arguments of the pipeline of RWKV.
"""
Expand Down
19 changes: 11 additions & 8 deletions src/formatron/integrations/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@ def _multiple_replace(replacements, text):
# Create a regular expression from the dictionary keys
regex = re.compile(b"(%s)" % b"|".join(map(re.escape, replacements.keys())))
# For each match, look-up corresponding value in dictionary
return regex.sub(lambda mo:replacements.get(mo.group(), b""), text)
return regex.sub(lambda mo: replacements.get(mo.group(), b""), text)


# Set of vocabulary post-processor names reported by `_autodetect_processors`.
# "dot_C" is listed because the detector adds that name to its result set,
# so the alias must admit it alongside the other three names.
Processors = set[typing.Literal["sentencepiece", "<0xHH>", "dot_G", "dot_C"]]

def _autodetect_processors(vocab:typing.Dict[str, int]):

def _autodetect_processors(vocab: typing.Dict[str, int]):
result = set()
llama_present = any(i.find('<0xF0>')!=-1 for i in vocab.keys())
underscore_present = (len([1 for i in vocab.keys() if i.find('\u2581')!=-1]) / len(vocab)) > 0.2
g_present = (len([1 for i in vocab.keys() if i.find('\u0120')!=-1]) / len(vocab)) > 0.2
llama_present = any(i.find('<0xF0>') != -1 for i in vocab.keys())
underscore_present = (len([1 for i in vocab.keys() if i.find('\u2581') != -1]) / len(vocab)) > 0.2
g_present = (len([1 for i in vocab.keys() if i.find('\u0120') != -1]) / len(vocab)) > 0.2
c_present = any(i.find('\u010A') != -1 for i in vocab.keys())
if llama_present:
result.add("<0xHH>")
Expand All @@ -26,7 +28,8 @@ def _autodetect_processors(vocab:typing.Dict[str, int]):
result.add("dot_C")
return result

def get_original_characters(vocab:typing.Dict[str, int]) -> typing.Dict[bytes, int]:

def get_original_characters(vocab: typing.Dict[str, int]) -> typing.Dict[bytes, int]:
old_char_to_new_char = {}
processors = _autodetect_processors(vocab)
for i in processors:
Expand All @@ -38,7 +41,7 @@ def get_original_characters(vocab:typing.Dict[str, int]) -> typing.Dict[bytes, i
old_char_to_new_char["\u010A".encode("UTF-8")] = b"\n"
elif i == "<0xHH>":
for j in range(256):
old_char_to_new_char[("<0x"+f"{j:02x}".upper()+">").encode("UTF-8")] = bytes([j])
old_char_to_new_char[("<0x" + f"{j:02x}".upper() + ">").encode("UTF-8")] = bytes([j])
else:
raise ValueError(f"{i} is not a valid processor name!")
new_vocab = {}
Expand All @@ -47,4 +50,4 @@ def get_original_characters(vocab:typing.Dict[str, int]) -> typing.Dict[bytes, i
k = k.encode("UTF-8")
new_k = _multiple_replace(old_char_to_new_char, k)
new_vocab[new_k] = token_id
return new_vocab
return new_vocab
5 changes: 3 additions & 2 deletions src/formatron/integrations/exllamav2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class FormatterFilter(ExLlamaV2Filter):
"""
ExLlamaV2Filter that uses a formatter to mask logits.
"""

def __init__(self, model, tokenizer, formatter: Formatter,
config: EngineGenerationConfig = None):
super().__init__(model, tokenizer)
Expand Down Expand Up @@ -78,5 +79,5 @@ def next(self) -> typing.Tuple[typing.Set[int], typing.Set[int]]:
return pass_tokens, end_tokens

@property
def formatter_captures(self)->dict[str, typing.Any]:
return self._formatter.captures
def formatter_captures(self) -> dict[str, typing.Any]:
return self._formatter.captures
1 change: 1 addition & 0 deletions src/formatron/schemas/dict_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def annotation(self) -> typing.Type[typing.Any] | None:
"""
return self._annotation

@property
def required(self) -> bool:
"""
Check if the field is required for the schema.
Expand Down

0 comments on commit c6e3f02

Please sign in to comment.