diff --git a/jsonformer/main.py b/jsonformer/main.py index 9c13471..cebafe0 100644 --- a/jsonformer/main.py +++ b/jsonformer/main.py @@ -5,6 +5,8 @@ OutputNumbersTokens, StringStoppingCriteria, ) +from jsonformer.type_prefixes import get_prefix_tokens_for_types + from termcolor import cprint from transformers import PreTrainedModel, PreTrainedTokenizer import json @@ -33,6 +35,8 @@ def __init__( self.json_schema = json_schema self.prompt = prompt + self.type_prefix_tokens = get_prefix_tokens_for_types(tokenizer) + self.number_logit_processor = OutputNumbersTokens(self.tokenizer, self.prompt) self.generation_marker = "|GENERATION|" @@ -147,6 +151,36 @@ def generate_object( obj[key] = self.generate_value(schema, obj, key) return obj + def choose_type_to_generate(self, possible_types: List[str]) -> str: + possible_types = list(set(possible_types)) # remove duplicates + self.debug("[choose_type_to_generate]", possible_types) + if len(possible_types) < 1: + raise ValueError("Union type must not be empty") + elif len(possible_types) == 1: + return possible_types[0] + + prompt = self.get_prompt() + input_tensor = self.tokenizer.encode(prompt, return_tensors="pt") + output = self.model.forward(input_tensor.to(self.model.device)) + logits = output.logits[0, -1] + + max_type = None + max_logit = -float("inf") + for possible_type in possible_types: + try: + prefix_tokens = self.type_prefix_tokens[possible_type] + except KeyError: + raise ValueError(f"Unsupported schema type: {possible_type}") + max_type_logit = logits[prefix_tokens].max() + if max_type_logit > max_logit: + max_type = possible_type + max_logit = max_type_logit + + if max_type is None: + raise Exception("Unable to find best type to generate for union type") + self.debug("[choose_type_to_generate]", max_type) + return max_type + def generate_value( self, schema: Dict[str, Any], @@ -154,6 +188,12 @@ def generate_value( key: Union[str, None] = None, ) -> Any: schema_type = schema["type"] + if 
isinstance(schema_type, list): + if key: + obj[key] = self.generation_marker + else: + obj.append(self.generation_marker) + schema_type = self.choose_type_to_generate(schema_type) if schema_type == "number": if key: obj[key] = self.generation_marker @@ -183,6 +223,8 @@ else: obj.append(new_obj) return self.generate_object(schema["properties"], new_obj) + elif schema_type == "null": + return None else: raise ValueError(f"Unsupported schema type: {schema_type}") diff --git a/jsonformer/type_prefixes.py b/jsonformer/type_prefixes.py new file mode 100644 index 0000000..c60eefd --- /dev/null +++ b/jsonformer/type_prefixes.py @@ -0,0 +1,32 @@ +from transformers import PreTrainedTokenizer +from typing import Dict, List +import re + +def is_number_prefix(s: str) -> bool: + return bool(re.match(r"^[\-\d]+\.?[\d]*$", s)) + +def is_boolean_prefix(s: str) -> bool: + return 'true'.startswith(s) or 'false'.startswith(s) + +def is_null_prefix(s: str) -> bool: + return 'null'.startswith(s) + +def is_string_prefix(s: str) -> bool: + return bool(re.match(r'^"[^"]*"?$', s)) + +def is_array_prefix(s: str) -> bool: + return bool(re.match(r'^\[["\-\d\[{]*$', s)) + +def is_object_prefix(s: str) -> bool: + return bool(re.match(r'^\{"?$', s)) + +def get_prefix_tokens_for_types(tokenizer: PreTrainedTokenizer) -> Dict[str, List[int]]: + vocab = tokenizer.vocab.items() + return { + "number": [v for k, v in vocab if is_number_prefix(k)], + "boolean": [v for k, v in vocab if is_boolean_prefix(k)], + "null": [v for k, v in vocab if is_null_prefix(k)], + "string": [v for k, v in vocab if is_string_prefix(k)], + "array": [v for k, v in vocab if is_array_prefix(k)], + "object": [v for k, v in vocab if is_object_prefix(k)], + }