Skip to content

Commit

Permalink
mypy
Browse files (browse the repository at this point in the history)
  • Loading branch information
parkervg committed Jun 10, 2024
1 parent 59bb061 commit cf4ece1
Show file tree
Hide file tree
Showing 16 changed files with 60 additions and 44 deletions.
7 changes: 5 additions & 2 deletions blendsql/_program.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations
from typing import Tuple, Callable
from typing import Tuple
import inspect
from outlines.models import LogitsGenerator
import ast
import textwrap
import logging
from colorama import Fore
from abc import abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand Down Expand Up @@ -54,6 +56,7 @@ def __new__(
):
return self.__call__(self, model, **kwargs)

@abstractmethod
def __call__(self, model: Model, *args, **kwargs) -> Tuple[str, str]:
"""Logic for formatting prompt and calling the underlying model.
Should return tuple of (response, prompt).
Expand All @@ -62,7 +65,7 @@ def __call__(self, model: Model, *args, **kwargs) -> Tuple[str, str]:


def return_ollama_response(
logits_generator: Callable, prompt, **kwargs
logits_generator: LogitsGenerator, prompt, **kwargs
) -> Tuple[str, str]:
"""Helper function to work with Ollama models,
since they're not recognized in the Outlines ecosystem.
Expand Down
10 changes: 7 additions & 3 deletions blendsql/_sqlglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,13 @@ def replace_join_with_ingredient_multiple_ingredient(
if anon_child_node.name == ingredient_name:
join_alias = ingredient_alias
continue
to_return.append(
anon_child_node.parent.parent.parent.sql(dialect=FTS5SQLite)
)
# Traverse and get the whole ingredient
# We need to go up 3 parents
_parent = anon_child_node
for _ in range(3):
_parent = _parent.parent
assert isinstance(_parent, exp.Expression)
to_return.append(_parent.sql(dialect=FTS5SQLite))
if len(to_return) == 0:
return node
# temp_uuid is used to ensure a partial query that is parse-able by sqlglot
Expand Down
3 changes: 1 addition & 2 deletions blendsql/db/_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def __init__(self, db_path: str):
raise ImportError(
"Please install psycopg2 with `pip install psycopg2-binary`!"
) from None
self._raw_db_path = db_path
db_url: URL = make_url(f"postgresql+psycopg2://{self._raw_db_path}")
db_url: URL = make_url(f"postgresql+psycopg2://{db_path}")
if db_url.username is None:
logging.warning(
Fore.RED
Expand Down
1 change: 0 additions & 1 deletion blendsql/db/_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
class SQLAlchemyDatabase(Database):
db_url: URL = attrib()

_raw_db_path: str = attrib(init=False)
engine: Engine = attrib(init=False)
con: Connection = attrib(init=False)

Expand Down
6 changes: 3 additions & 3 deletions blendsql/db/_sqlite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from sqlalchemy.engine import make_url, URL
from functools import cached_property
from typing import Dict

from .utils import double_quote_escape
from ._sqlalchemy import SQLAlchemyDatabase
Expand All @@ -18,8 +19,7 @@ class SQLite(SQLAlchemyDatabase):
"""

def __init__(self, db_path: str):
self._raw_db_path = Path(db_path).resolve()
db_url: URL = make_url(f"sqlite:///{self._raw_db_path}")
db_url: URL = make_url(f"sqlite:///{Path(db_path).resolve()}")
super().__init__(db_url=db_url)

def has_temp_table(self, tablename: str) -> bool:
Expand All @@ -36,7 +36,7 @@ def sqlglot_schema(self) -> dict:
>>> db.sqlglot_schema
{"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}}
"""
schema = {}
schema: Dict[str, dict] = {}
for tablename in self.tables():
schema[f'"{double_quote_escape(tablename)}"'] = {}
for _, row in self.execute_to_df(
Expand Down
4 changes: 2 additions & 2 deletions blendsql/db/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import pandas as pd
from typing import Callable
from attr import attrs, attrib
from functools import partial


@attrs(frozen=True)
Expand All @@ -12,7 +12,7 @@ class LazyTable:
"""

tablename: str = attrib()
collect: partial[..., pd.DataFrame] = attrib()
collect: Callable[..., pd.DataFrame] = attrib()

def __str__(self):
return self.tablename
Expand Down
3 changes: 2 additions & 1 deletion blendsql/grammars/minEarley/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def pattern_to_candidates(pattern: Pattern):

# TODO: handle case where no candidate is found
if len(candidate_terminals) == 0:
candidate_terminals = [""]
candidate_terminals = {""}

elif isinstance(e, UnexpectedEOF):
candidate_terminals = set()
for terminal_name in e.expected:
Expand Down
4 changes: 2 additions & 2 deletions blendsql/grammars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def load_cfg_parser(
ingredient_type_to_function_type[ingredient.ingredient_type].append(
ingredient.__name__
)
cfg_grammar = cfg_grammar.substitute(
cfg_grammar_str: str = cfg_grammar.substitute(
blendsql_join_functions=format_ingredient_names_to_lark(
blendsql_join_functions
),
Expand All @@ -62,7 +62,7 @@ def load_cfg_parser(
),
)
return EarleyParser(
grammar=cfg_grammar,
grammar=cfg_grammar_str,
start="start",
keep_all_tokens=True,
)
Expand Down
3 changes: 2 additions & 1 deletion blendsql/ingredients/builtin/map/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def run(
Args:
question: The question to map onto the values. Will also be the new column name
model: The Model (blender) we will make calls to.
values: The list of values to apply question to.
value_limit: Optional limit on the number of values to pass to the Model
example_outputs: If binary == False, this gives the Model an example of the output we expect.
output_type: One of 'numeric', 'string', 'bool'
Expand All @@ -177,7 +178,7 @@ def run(
logger.debug(f"Tablename {tablename} not in given table_to_title!")
else:
table_title = table_to_title[tablename]
split_results = []
split_results: List[Union[str, None]] = []
# Only use tqdm if we're in debug mode
context_manager = (
tqdm(
Expand Down
10 changes: 5 additions & 5 deletions blendsql/ingredients/builtin/qa/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import Dict, Union, Optional, Set, List, Tuple
from typing import Dict, Union, Optional, Set, Tuple
import pandas as pd
import outlines
import re
Expand All @@ -17,7 +17,7 @@ def __call__(
model: Model,
question: str,
context: Optional[pd.DataFrame] = None,
options: Optional[List[str]] = None,
options: Optional[Set[str]] = None,
long_answer: Optional[bool] = False,
table_title: Optional[str] = None,
max_tokens: Optional[int] = None,
Expand All @@ -26,13 +26,13 @@ def __call__(
prompt = ""
serialized_db = context.to_string() if context is not None else ""
prompt += "Answer the question for the table. "
modified_option_to_original = {}
if long_answer:
prompt += "Make the answer as concrete as possible, providing more context and reasoning using the entire table.\n"
else:
prompt += "Keep the answers as short as possible, without leading context. For example, do not say 'The answer is 2', simply say '2'.\n"
if options is not None:
# Add in title case, since this helps with selection
modified_option_to_original = {}
_options = copy.deepcopy(options)
# Below we check to see if our options have a unique first word
# sometimes, the model will generate 'Frank' instead of 'Frank Smith'
Expand Down Expand Up @@ -63,9 +63,9 @@ def __call__(
generator = outlines.generate.choice(
model.logits_generator, [re.escape(str(i)) for i in options]
)
response: str = generator(prompt, max_tokens=max_tokens)
_response: str = generator(prompt, max_tokens=max_tokens)
# Map from modified options to original, as they appear in DB
response: str = modified_option_to_original.get(response, response)
response: str = modified_option_to_original.get(_response, _response)
else:
if isinstance(model, OllamaLLM):
# Handle call to ollama
Expand Down
18 changes: 12 additions & 6 deletions blendsql/ingredients/ingredient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Set,
Collection,
Optional,
List,
)
import uuid
from colorama import Fore
Expand Down Expand Up @@ -161,7 +162,7 @@ def __call__(self, question: str, context: str, *args, **kwargs) -> tuple:
kwargs[IngredientKwarg.QUESTION] = question
mapped_values: Collection[Any] = self._run(*args, **kwargs)
self.num_values_passed += len(mapped_values)
df_as_dict = {colname: [], new_arg_column: []}
df_as_dict: Dict[str, list] = {colname: [], new_arg_column: []}
for value, mapped_value in zip(values, mapped_values):
df_as_dict[colname].append(value)
df_as_dict[new_arg_column].append(mapped_value)
Expand Down Expand Up @@ -356,24 +357,29 @@ def __call__(
# Unpack kwargs
aliases_to_tablenames: Dict[str, str] = kwargs.get("aliases_to_tablenames")

subtable = context
subtable: Union[pd.DataFrame, None] = None
if context is not None:
if isinstance(context, str):
tablename, colname = utils.get_tablename_colname(context)
# Optionally materialize a CTE
if tablename in self.db.lazy_tables:
subtable = self.db.lazy_tables.pop(tablename).collect()[colname]
subtable: pd.DataFrame = self.db.lazy_tables.pop(
tablename
).collect()[colname]
else:
subtable = self.db.execute_to_df(
subtable: pd.DataFrame = self.db.execute_to_df(
f'SELECT "{colname}" FROM "{tablename}"'
)
elif not isinstance(context, pd.DataFrame):
elif isinstance(context, pd.DataFrame):
subtable = context
else:
raise ValueError(
f"Unknown type for `identifier` arg in QAIngredient: {type(context)}"
)
if subtable.empty:
raise IngredientException("Empty subtable passed to QAIngredient!")
unpacked_options = options

unpacked_options: Union[List[str], None] = options
if options is not None:
if not isinstance(options, list):
try:
Expand Down
2 changes: 2 additions & 0 deletions blendsql/models/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def predict(self, program: Type[Program], **kwargs) -> str:
)
return self.cache.get(key)
# Modify fields used for tracking Model usage
response: str
prompt: str
response, prompt = program(model=self, **kwargs)
self.prompts.insert(-1, self.format_prompt(response, **kwargs))
self.num_calls += 1
Expand Down
4 changes: 2 additions & 2 deletions blendsql/models/remote/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class AzureOpenaiLLM(RemoteModel):
def __init__(
self,
model_name_or_path: str,
env: Optional[str] = None,
env: str = ".",
config: Optional[OpenAIConfig] = None,
caching: bool = True,
**kwargs
Expand Down Expand Up @@ -149,7 +149,7 @@ class OpenaiLLM(RemoteModel):
def __init__(
self,
model_name_or_path: str,
env: Optional[str] = None,
env: str = ".",
config: Optional[OpenAIConfig] = None,
caching: bool = True,
**kwargs
Expand Down
12 changes: 6 additions & 6 deletions blendsql/nl_to_blendsql/args.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
from typing import Optional, Collection, List, Union
from typing import Collection, List, Union
from dataclasses import dataclass, field


@dataclass
class NLtoBlendSQLArgs:
max_grammar_corrections: Optional[int] = field(
max_grammar_corrections: int = field(
default=0,
metadata={
"help": "Optional int defining maximum CFG-guided correction steps to be taken. This is based on the method in https://arxiv.org/pdf/2305.19234."
},
)

include_db_content_tables: Optional[Union[List[str], str]] = field(
include_db_content_tables: Union[List[str], str] = field(
default="all",
metadata={
"help": "Which database tables to add `num_serialized_rows` worth of content for in serialization."
},
)

num_serialized_rows: Optional[int] = field(
num_serialized_rows: int = field(
default=3,
metadata={
"help": "How many example rows to include in serialization of database"
},
)

use_tables: Optional[Collection[str]] = field(
use_tables: Collection[str] = field(
default=None,
metadata={"help": "Collection of tables to use in serialization to string"},
)

use_bridge_encoder: Optional[bool] = field(
use_bridge_encoder: bool = field(
default=True,
metadata={
"help": "Whether to use Bridge Content Encoder during input serialization"
Expand Down
13 changes: 7 additions & 6 deletions blendsql/nl_to_blendsql/nl_to_blendsql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Collection, List, Tuple, Set, Optional, Union, Type
from typing import Collection, Tuple, Set, Optional, Union, Type
from textwrap import dedent
import outlines
from colorama import Fore
Expand Down Expand Up @@ -65,7 +65,8 @@ def __call__(
temperature=0.0,
)
generator = outlines.generate.text(model.logits_generator)
return (generator(prompt, stop_at=PARSER_STOP_TOKENS), prompt)
response: str = generator(prompt, stop_at=PARSER_STOP_TOKENS)
return (response, prompt)


class CorrectionProgram(Program):
Expand All @@ -76,7 +77,7 @@ def __call__(
serialized_db: str,
question: str,
partial_completion: str,
candidates: List[str],
candidates: Set[str],
**kwargs,
) -> Tuple[str, str]:
if isinstance(model, OllamaLLM):
Expand Down Expand Up @@ -111,7 +112,7 @@ def obtain_correction_pairs(
prediction: str, parser: EarleyParser
) -> Tuple[str, Set[str], int]:
"""
Returns a list of candidates in the form of (prefix, suffix).
Returns a tuple of the form (prefix, candidates, error_position_index).
"""
try:
parser.parse(prediction)
Expand Down Expand Up @@ -258,15 +259,15 @@ def nl_to_blendsql(
prefix, candidates, pos_in_stream = obtain_correction_pairs(
program_prediction, parser
)
candidates = [i for i in candidates if i.strip() != ""]
# candidates = [i for i in candidates if i.strip() != ""]
if len(candidates) == 0:
logger.debug(
Fore.LIGHTMAGENTA_EX + "No correction pairs found" + Fore.RESET
)
return prefix
elif len(candidates) == 1:
# If we only have 1 candidate, no need to call LLM
selected_candidate = candidates[0]
selected_candidate = candidates.pop()
else:
# Generate the continuation candidate with the highest probability
selected_candidate = correction_model.predict(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_multi_table_blendsql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from blendsql import blend
from blendsql.db import SQLite, DuckDB
from blendsql.db import SQLite
from blendsql.utils import fetch_from_hub
from tests.utils import (
assert_equality,
Expand All @@ -14,7 +14,7 @@

databases = [
SQLite(fetch_from_hub("multi_table.db")),
DuckDB.from_sqlite(fetch_from_hub("multi_table.db")),
# DuckDB.from_sqlite(fetch_from_hub("multi_table.db")),
]


Expand Down

0 comments on commit cf4ece1

Please sign in to comment.