Skip to content

Commit

Permalink
Merge pull request #17 from parkervg/feature/improve-logging
Browse files Browse the repository at this point in the history
Improve Logging
  • Loading branch information
parkervg authored May 28, 2024
2 parents 91f0651 + 67e58c0 commit fead7c2
Show file tree
Hide file tree
Showing 15 changed files with 154 additions and 59 deletions.
48 changes: 21 additions & 27 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,20 @@
from blendsql.utils import fetch_from_hub
import chainlit as cl
from dotenv import load_dotenv
import textwrap

import json
from chainlit import make_async
import json

from blendsql import blend, LLMQA, LLMJoin, LLMMap
from blendsql.models import AzureOpenaiLLM
from blendsql.models import OpenaiLLM
from blendsql.db import SQLite
from research.prompts.parser_program import ParserProgram
from research.utils.database import to_serialized
from blendsql.prompts import FewShot
from blendsql.nl_to_blendsql import nl_to_blendsql, NLtoBlendSQLArgs

# Load environment variables from the local ".env" file before any model is built.
load_dotenv(".env")
# Both the parser and the blender are OpenaiLLM instances for "gpt-3.5-turbo".
parser_model = OpenaiLLM("gpt-3.5-turbo")
blender_model = OpenaiLLM("gpt-3.5-turbo")

DB_PATH = "./research/db/hybridqa/2004_United_States_Grand_Prix_0.db"
db = SQLite(DB_PATH, check_same_thread=False)
# NOTE(review): `db` is rebound on the next line, so the DB_PATH-backed
# connection above is discarded -- presumably that line is the pre-change side
# of this diff; confirm which binding is intended.
db = SQLite(fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db"))
# NOTE(review): the file objects returned by open() are never explicitly closed.
few_shot_prompt = open("./research/prompts/hybridqa/few_shot.txt").read()
ingredients_prompt = open("./research/prompts/hybridqa/ingredients.txt").read()
# Serialized form of `db` requested with num_rows=3 (exact output format is
# defined by to_serialized -- verify there).
serialized_db = to_serialized(db, num_rows=3)


def fewshot_parse(model, **input_program_args):
    """Run ``ParserProgram`` on ``model`` and return its dedented result.

    Every string keyword argument is passed through ``textwrap.dedent``
    before being forwarded; non-string arguments are forwarded untouched.

    Args:
        model: Object exposing ``predict(program=..., **kwargs)``.
        **input_program_args: Keyword arguments forwarded to the program.

    Returns:
        The ``"result"`` entry of the prediction, dedented.
    """
    dedented_args = {
        name: textwrap.dedent(value) if isinstance(value, str) else value
        for name, value in input_program_args.items()
    }
    prediction = model.predict(program=ParserProgram, **dedented_args)
    return textwrap.dedent(prediction["result"])


@cl.on_message # this function will be called every time a user inputs a message in the UI
Expand All @@ -40,19 +29,24 @@ async def main(message: cl.Message):
Returns:
None.
"""
parser_model = AzureOpenaiLLM("gpt-4")
blender_model = AzureOpenaiLLM("gpt-4")

async with cl.Step(
name="Fewshot Parse to BlendSQL", language="sql", type="llm"
) as parser_step:
parser_step.input = message.content
blendsql_query = await make_async(fewshot_parse)(
model=parser_model,
ingredients_prompt=ingredients_prompt,
few_shot_prompt=few_shot_prompt,
serialized_db=serialized_db,
blendsql_query = await make_async(nl_to_blendsql)(
question=message.content,
db=db,
model=parser_model,
ingredients={LLMQA, LLMMap, LLMJoin},
few_shot_examples=FewShot.hybridqa,
args=NLtoBlendSQLArgs(
use_tables=["w"],
include_db_content_tables=["w"],
num_serialized_rows=3,
use_bridge_encoder=True,
),
verbose=False,
)

parser_step.output = blendsql_query
Expand Down
80 changes: 80 additions & 0 deletions blendsql/_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import logging

# Apply logging's default root-logger configuration as soon as this module is
# imported.
# NOTE(review): this is an import-time side effect on the process-wide root
# logger -- confirm it is wanted for library consumers.
logging.basicConfig()


def msg_box(msg, indent=1, width=None, title=None):
    """Return ``msg`` framed with box-drawing characters and an optional title.

    Args:
        msg: Text to frame. It is split on newlines, one box row per line.
        indent: Number of spaces padding each row on both sides.
        width: Width of the text area. When falsy it is derived from the
            longest message line and, if given, the title.
        title: Optional heading rendered above the message and underlined
            with dashes.

    Returns:
        The box as a single string, without a trailing newline.
    """
    lines = msg.split("\n")
    space = " " * indent
    if not width:
        # Include the title when sizing the frame so the horizontal borders
        # are never shorter than the heading row.
        measured = lines + [title] if title else lines
        width = max(map(len, measured))
    border = "═" * (width + indent * 2)
    # NOTE(review): rows have no right-hand edge and only the bottom-right
    # corner is drawn; kept as-is -- confirm this open-sided look is intended.
    box = f"╔{border}\n"  # upper border
    if title:
        underline = "-" * len(title)
        box += f"║{space}{title:<{width}}{space}\n"  # title
        box += f"║{space}{underline:<{width}}{space}\n"  # underline
    box += "".join(f"║{space}{line:<{width}}{space}\n" for line in lines)
    box += f"╚{border}╝"  # lower border
    return box


_fmt_console_debug = "%(message)s"
_fmt_console_info = (
"\u001b[0;34m" + "[ INFO ]" + "\u001b[0m" + " (%(name)s) %(message)s"
)
_fmt_console_warning = (
"\u001b[33;20m" + "[ WARN ]" + "\u001b[0m" + " (%(name)s) %(message)s"
)
_fmt_console_error = (
"\u001b[0;31m" + "[ ERRO ]" + "\u001b[0m" + " (%(name)s) %(message)s"
)
_fmt_console_critical = (
"\u001b[1;31m" + "[ CRIT ]" + " (%(name)s) %(message)s" + "\u001b[0m"
)


class _FormatterConsole(logging.Formatter):
def __init__(self, time: bool):
time_fmt = "[ %H:%M:%S ] " if time else ""
self.formatters = {
logging.DEBUG: logging.Formatter(fmt=_fmt_console_debug, datefmt=time_fmt),
logging.INFO: logging.Formatter(fmt=_fmt_console_info, datefmt=time_fmt),
logging.WARNING: logging.Formatter(
fmt=_fmt_console_warning, datefmt=time_fmt
),
logging.ERROR: logging.Formatter(fmt=_fmt_console_error, datefmt=time_fmt),
logging.CRITICAL: logging.Formatter(
fmt=_fmt_console_critical, datefmt=time_fmt
),
}

def format(self, record):
return self.formatters[record.levelno].format(record)


def consoleHandler(
    time: bool = True, level: int = logging.INFO
) -> logging.StreamHandler:
    """Build a ``StreamHandler`` that formats with ``_FormatterConsole``.

    Args:
        time: Forwarded to ``_FormatterConsole``.
        level: Threshold set on the returned handler.

    Returns:
        The configured stream handler.
    """
    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.setFormatter(_FormatterConsole(time))
    return handler


class Logger(logging.Logger):
    """``logging.Logger`` that attaches its own console handler on creation.

    Args:
        name: Logger name.
        level: Threshold applied to the attached console handler. The
            logger's own level is not set from this value.
        time: Forwarded to ``consoleHandler``.
    """

    def __init__(
        self,
        name: str,
        level: int = logging.INFO,
        time: bool = True,
    ):
        self._time = time
        # Remember the handler threshold: it is not passed to the base
        # class, so ``self.level`` stays NOTSET until setLevel() is called.
        self._console_level = level
        super().__init__(name)
        self.addHandler(consoleHandler(time, level))

    def getChild(self, name: str) -> logging.Logger:
        """Return a new ``Logger`` named ``<self.name>.<name>``.

        The child's console handler uses this logger's explicitly set level
        when there is one, and otherwise the threshold this logger's own
        handler was built with, so the child does not fall back to NOTSET
        and accept every record.
        """
        child_level = self.level or self._console_level
        return Logger(self.name + "." + name, child_level, self._time)


# Package-wide logger named "blendsql"; its console handler is created with a
# DEBUG threshold.
logger = Logger("blendsql", logging.DEBUG)
2 changes: 1 addition & 1 deletion blendsql/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

if TYPE_CHECKING:
from .models import Model
from .utils import logger
from ._logger import logger


class Program:
Expand Down
15 changes: 12 additions & 3 deletions blendsql/_smoothie.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
import pandas as pd

from .ingredients import Ingredient
from .utils import tabulate
from .db.utils import truncate_df_content

"""
Defines output of an executed BlendSQL script
"""

class PrettyDataFrame(pd.DataFrame):
    """DataFrame whose ``str()``/``repr()`` is a tabulated, truncated table.

    The frame is passed through ``truncate_df_content(self, 50)`` and the
    result is rendered with ``tabulate``.
    """

    def __str__(self):
        shortened = truncate_df_content(self, 50)
        return tabulate(shortened)

    # repr() yields exactly the same rendering as str().
    __repr__ = __str__


@dataclass
Expand All @@ -26,3 +32,6 @@ class SmoothieMeta:
class Smoothie:
df: pd.DataFrame
meta: SmoothieMeta

def __post_init__(self):
self.df = PrettyDataFrame(self.df)
12 changes: 2 additions & 10 deletions blendsql/blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from colorama import Fore
import string

from ._logger import logger, msg_box
from .utils import (
logger,
sub_tablename,
get_temp_session_table,
get_temp_subquery_table,
Expand Down Expand Up @@ -791,15 +791,7 @@ def _blend(
a, f'"{double_quote_escape(_get_temp_session_table(t))}"', query
)

logger.debug("")
logger.debug(
"**********************************************************************************"
)
logger.debug(Fore.LIGHTGREEN_EX + f"Final Query:\n{query}" + Fore.RESET)
logger.debug(
"**********************************************************************************"
)
logger.debug("")
logger.debug(Fore.LIGHTGREEN_EX + msg_box(f"Final Query:\n{query}") + Fore.RESET)

df = db.execute_to_df(query)

Expand Down
2 changes: 1 addition & 1 deletion blendsql/db/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pandas.io.sql import get_schema
from abc import abstractmethod

from ..utils import logger
from .._logger import logger
from .utils import double_quote_escape, truncate_df_content
from .bridge_content_encoder import get_database_matches

Expand Down
2 changes: 1 addition & 1 deletion blendsql/grammars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from string import Template
from colorama import Fore

from ..utils import logger
from .._logger import logger
from ..ingredients import Ingredient
from .._constants import IngredientType
from .minEarley.parser import EarleyParser
Expand Down
3 changes: 2 additions & 1 deletion blendsql/ingredients/builtin/join/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from blendsql.models import Model, LocalModel, OllamaLLM
from blendsql._program import Program, return_ollama_response
from blendsql._logger import logger
from blendsql import _constants as CONST
from blendsql.ingredients.ingredient import JoinIngredient
from blendsql.utils import logger, newline_dedent
from blendsql.utils import newline_dedent


class JoinProgram(Program):
Expand Down
5 changes: 3 additions & 2 deletions blendsql/ingredients/builtin/map/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from tqdm import tqdm
import outlines

from blendsql.utils import logger, newline_dedent
from blendsql.utils import newline_dedent
from blendsql._logger import logger
from blendsql.models import Model, LocalModel, RemoteModel, OpenaiLLM, OllamaLLM
from ast import literal_eval
from blendsql import _constants as CONST
Expand Down Expand Up @@ -128,7 +129,7 @@ def __call__(
temperature=0.0,
)
generator = outlines.generate.text(model.logits_generator)
return (generator(prompt, max_tokens=max_tokens), prompt)
return (generator(prompt, max_tokens=max_tokens, stop_at="\n"), prompt)


class LLMMap(MapIngredient):
Expand Down
5 changes: 2 additions & 3 deletions blendsql/models/_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import functools
from typing import Any, List, Optional, Type
import pandas as pd
from attr import attrib, attrs
Expand All @@ -13,7 +12,7 @@
from abc import abstractmethod
from outlines.models import LogitsGenerator

from ..utils import logger
from .._logger import logger
from .._program import Program, program_to_str
from .._constants import IngredientKwarg
from ..db.utils import truncate_df_content
Expand Down Expand Up @@ -143,7 +142,7 @@ def _create_key(self, program: Program, **kwargs) -> str:
[
(k, sorted(v) if isinstance(v, set) else v)
for k, v in kwargs.items()
if not isinstance(v, functools.partial)
if not callable(v)
]
)
)
Expand Down
26 changes: 22 additions & 4 deletions blendsql/models/local/_llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,39 @@ class LlamaCppLLM(LocalModel):
"""

def __init__(
self, model_name_or_path: str, filename: str, caching: bool = True, **kwargs
self,
model_name_or_path: str,
filename: str,
hf_repo_with_config: str = None,
caching: bool = True,
**kwargs
):
if not _has_llama_cpp:
raise ImportError(
"Please install llama_cpp with `pip install llama-cpp-python`!"
) from None
from llama_cpp import llama_tokenizer

self._llama_tokenizer = None
if hf_repo_with_config:
self._llama_tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(
hf_repo_with_config
)

super().__init__(
model_name_or_path=model_name_or_path,
# TODO: how to get llama_cpp tokenizer?
tokenizer=None,
tokenizer=self._llama_tokenizer.hf_tokenizer
if self._llama_tokenizer is not None
else None,
requires_config=False,
load_model_kwargs=kwargs | {"filename": filename},
caching=caching,
)

def _load_model(self, filename: str, **kwargs) -> LogitsGenerator:
    """Load the model via ``llamacpp`` and return the resulting generator.

    Args:
        filename: Model file name forwarded to ``llamacpp``.
        **kwargs: Additional keyword arguments forwarded to ``llamacpp``.

    Returns:
        Whatever ``llamacpp`` returns for ``self.model_name_or_path``.
    """
    # Pass the tokenizer stored on the instance (``None`` when no
    # ``hf_repo_with_config`` was supplied). The earlier tokenizer-less
    # ``return`` made this call unreachable, so it has been removed.
    return llamacpp(
        self.model_name_or_path,
        filename=filename,
        tokenizer=self._llama_tokenizer,
        **kwargs,
    )
2 changes: 2 additions & 0 deletions blendsql/models/local/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def __init__(self, model_name_or_path: str, caching: bool = True, **kwargs):
) from None
import transformers

transformers.logging.set_verbosity_error()

super().__init__(
model_name_or_path=model_name_or_path,
requires_config=False,
Expand Down
2 changes: 1 addition & 1 deletion blendsql/nl_to_blendsql/nl_to_blendsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
import logging

from ..utils import logger
from .._logger import logger
from ..ingredients import Ingredient, IngredientException
from ..models import Model, OllamaLLM
from ..db import Database, double_quote_escape
Expand Down
7 changes: 3 additions & 4 deletions blendsql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import re
from tabulate import tabulate
from functools import partial
import logging

from ._constants import HF_REPO_ID

# Apply logging's default root-logger configuration at import time and expose
# the "blendsql" logger.
logging.basicConfig()
logger = logging.getLogger("blendsql")
# NOTE(review): `tabulate` is rebound twice. The second partial wraps the
# first, and its tablefmt="simple_outline" overrides "orgtbl" -- presumably the
# first binding is the pre-change side of this diff; confirm.
tabulate = partial(tabulate, headers="keys", showindex="never", tablefmt="orgtbl")
tabulate = partial(
tabulate, headers="keys", showindex="never", tablefmt="simple_outline"
)
def newline_dedent(x):
    """Return ``x`` with leading whitespace stripped from every line.

    Lines are split and re-joined on ``"\\n"``; trailing whitespace and empty
    lines are preserved.
    """
    return "\n".join(line.lstrip() for line in x.split("\n"))


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def find_version(*file_paths):
"attrs",
"tqdm",
"colorama",
"tabulate",
"tabulate>=0.9.0",
"typeguard",
"rapidfuzz",
],
Expand Down

0 comments on commit fead7c2

Please sign in to comment.