From 97843ee389f6d8175c6cfee287d3bfe785476e3f Mon Sep 17 00:00:00 2001 From: parkervg Date: Thu, 18 Jul 2024 10:37:55 -0400 Subject: [PATCH] starting to do tests with guidance --- blendsql/generate/regex.py | 16 +++++++++++++--- blendsql/models/local/_transformers.py | 7 ++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/blendsql/generate/regex.py b/blendsql/generate/regex.py index c728f26..886f2a4 100644 --- a/blendsql/generate/regex.py +++ b/blendsql/generate/regex.py @@ -1,6 +1,6 @@ from functools import singledispatch from typing import Optional, List, Union -import outlines +from guidance import capture, gen from ..models import Model, OllamaLLM @@ -13,8 +13,18 @@ def regex( max_tokens: Optional[int] = None, stop_at: Optional[Union[List[str], str]] = None, ) -> str: - generator = outlines.generate.regex(model.model_obj, regex_str=regex) - return generator(prompt, max_tokens=max_tokens, stop_at=stop_at) + print(regex) + regex = "((t|f|-);)((t|f|-);)((t|f|-);)((t|f|-);)" + res = ( + model.model_obj + + prompt + + capture( + gen(regex=regex, max_tokens=max_tokens or 1e10, stop=stop_at), name="res" + ) + ) + return res["res"] + # generator = outlines.generate.regex(model.model_obj, regex_str=regex) + # return generator(prompt, max_tokens=max_tokens, stop_at=stop_at) @regex.register(OllamaLLM) diff --git a/blendsql/models/local/_transformers.py b/blendsql/models/local/_transformers.py index 2c9357f..4ba4de4 100644 --- a/blendsql/models/local/_transformers.py +++ b/blendsql/models/local/_transformers.py @@ -49,11 +49,12 @@ def __init__(self, model_name_or_path: str, caching: bool = True, **kwargs): def _load_model(self) -> ModelObj: # https://huggingface.co/blog/how-to-generate - from outlines.models import transformers + from guidance.models import Transformers - return transformers( + return Transformers( self.model_name_or_path, - model_kwargs=self.load_model_kwargs, + echo=False, + **self.load_model_kwargs, )