Skip to content

Commit

Permalink
Moving to guidance for quicker runtimes
Browse files Browse the repository at this point in the history
  • Loading branch information
parkervg committed Aug 26, 2024
1 parent e9f388c commit 1c706e5
Show file tree
Hide file tree
Showing 21 changed files with 225 additions and 525 deletions.
2 changes: 1 addition & 1 deletion benchmark/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

outlines.caching.clear_cache()

MODEL = TransformersLLM("hf-internal-testing/tiny-random-PhiForCausalLM", caching=False)
MODEL = TransformersLLM("HuggingFaceTB/SmolLM-135M", caching=False)
NUM_ITER_PER_QUERY = 5

if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion blendsql/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __contains__(cls, item):

DEFAULT_ANS_SEP = ";"
DEFAULT_NAN_ANS = "-"
MAP_BATCH_SIZE = 5
MAP_BATCH_SIZE = 15


class IngredientType(str, Enum, metaclass=StrInMeta):
Expand Down
2 changes: 1 addition & 1 deletion blendsql/blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def _blend(
if kwargs_dict.get(IngredientKwarg.REGEX, None) is not None:
logger.debug(
Fore.LIGHTBLACK_EX
+ f"Using regex '{kwargs_dict[IngredientKwarg.REGEX](1)}'"
+ f"Using regex '{kwargs_dict[IngredientKwarg.REGEX]}'"
+ Fore.RESET
)
if table_to_title is not None:
Expand Down
3 changes: 0 additions & 3 deletions blendsql/generate/__init__.py

This file was deleted.

22 changes: 0 additions & 22 deletions blendsql/generate/choice.py

This file was deleted.

28 changes: 0 additions & 28 deletions blendsql/generate/regex.py

This file was deleted.

151 changes: 60 additions & 91 deletions blendsql/ingredients/builtin/join/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from typing import List, Optional, Tuple
import re
from colorama import Fore
import guidance

from blendsql.models import Model, LocalModel
from blendsql.models import Model
from blendsql._program import Program
from blendsql._logger import logger
from blendsql import _constants as CONST
from blendsql.ingredients.ingredient import JoinIngredient
from blendsql.utils import newline_dedent
from blendsql import generate


class JoinProgram(Program):
Expand All @@ -21,88 +18,69 @@ def __call__(
sep: str,
**kwargs,
) -> Tuple[str, str]:
prompt = ""
prompt += "You are a database expert in charge of performing a modified `LEFT JOIN` operation. This `LEFT JOIN` is based on a semantic criteria given by the user."
prompt += f"\nThe left and right value alignment should be separated by '{sep}', with each new `JOIN` alignment goin on a newline. If a given left value has no corresponding right value, give '-' as a response."
prompt += newline_dedent(
"""
Criteria: Join to same topics.
Left Values:
joshua fields
bob brown
ron ryan
Right Values:
ron ryan
colby mules
bob brown (ice hockey)
josh fields (pitcher)
Output:
joshua fields;josh fields (pitcher)
bob brown;bob brown (ice hockey)
ron ryan;ron ryan
---
"""
)
prompt += newline_dedent(
"""
Criteria: {}
m: guidance.models.Model = model.model_obj
with guidance.system():
m += "You are a database expert in charge of performing a modified `LEFT JOIN` operation. This `LEFT JOIN` is based on a semantic criteria given by the user."
m += f"\nThe left and right value alignment should be separated by '{sep}', with each new `JOIN` alignment goin on a newline. If a given left value has no corresponding right value, give '-' as a response."
m += newline_dedent(
"""
Criteria: Join to same topics.
Left Values:
{}
joshua fields
bob brown
ron ryan
Right Values:
{}
ron ryan
colby mules
bob brown (ice hockey)
josh fields (pitcher)
Output:
""".format(
join_criteria, "\n".join(left_values), "\n".join(right_values)
)
)
# Create this pattern on the fly, and not in infer_gen_constraints
# since it depends on what our left/right values are
regex = (
lambda num_repeats: "(({}){}({})\n)".format(
"|".join([re.escape(i) for i in left_values]),
CONST.DEFAULT_ANS_SEP,
"|".join(
[re.escape(i) for i in right_values] + [CONST.DEFAULT_NAN_ANS]
),
{
"joshua fields": "josh fields (pitcher)",
"bob brown": "bob brown (ice hockey)",
"ron ryan": "ron ryan"
}
---
"""
)
+ "{"
+ str(num_repeats)
+ "}"
)
max_tokens = (
len(
model.tokenizer.encode(
"".join(left_values)
+ "".join(right_values)
+ (CONST.DEFAULT_ANS_SEP * len(left_values)),
with guidance.user():
m += newline_dedent(
"""
Criteria: {}
Left Values:
{}
Right Values:
{}
Output:
""".format(
join_criteria, "\n".join(left_values), "\n".join(right_values)
)
)
if model.tokenizer is not None
else None
)
prompt = m._current_prompt()

if isinstance(model, LocalModel):
response = generate.regex(
model,
prompt=prompt,
regex=regex(len(left_values)),
max_tokens=max_tokens,
stop_at=["---"],
)
else:
response = generate.text(
model, prompt=prompt, max_tokens=max_tokens, stop_at=["---"]
)
logger.debug(Fore.CYAN + prompt + Fore.RESET)
logger.debug(Fore.LIGHTCYAN_EX + response + Fore.RESET)
return (response, prompt)
@guidance(stateless=True, dedent=False)
def make_predictions(lm, left_values, right_values):
lm += "{"
gen_f = guidance.select(options=right_values)
for idx, value in enumerate(left_values):
lm += (
f'\n\t"{value}": '
+ guidance.capture(gen_f, name=value)
+ ("," if idx + 1 != len(right_values) else "")
)
return lm

with guidance.assistant():
m += make_predictions(left_values=left_values, right_values=right_values)
return (m._variables, prompt)


class LLMJoin(JoinIngredient):
Expand All @@ -121,21 +99,12 @@ def run(
) -> dict:
if question is None:
question = "Join to same topics."
result = model.predict(
mapping = model.predict(
program=JoinProgram,
sep=CONST.DEFAULT_ANS_SEP,
left_values=left_values,
right_values=right_values,
join_criteria=question,
**kwargs,
)
# Post-process language model response
_result = result.split("\n")
mapping: dict = {}
for item in _result:
if CONST.DEFAULT_ANS_SEP in item:
k, v = item.rsplit(CONST.DEFAULT_ANS_SEP, 1)
if any(pred == CONST.DEFAULT_NAN_ANS for pred in {k, v}):
continue
mapping[k] = v
return mapping
return {k: v for k, v in mapping.items() if v != CONST.DEFAULT_NAN_ANS}
Loading

0 comments on commit 1c706e5

Please sign in to comment.