Skip to content

Commit

Permalink
more linting
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Jun 25, 2024
1 parent a3bdf8f commit bf6092e
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 38 deletions.
35 changes: 19 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,24 @@ python = ["3.10", "3.11", "3.12"]
[tool.hatch.envs.lint]
detached = false # Normally the linting env can be detached, but mypy doesn't install all the stubs we need
dependencies = [
"black>=23.1.0",
"mypy>=1.0.0",
"ruff>=0.0.243",
"pylint",
]

[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive {args:src/banks}"
style = [
check = [
"ruff format --check {args}",
"ruff check {args:.}",
"black --check --diff {args:.}",
]
fmt = [
"black {args:.}",
"ruff check --fix {args:.}",
"style",
]
lint = "pylint {args:src/banks}"
typing = "mypy --install-types --non-interactive {args:src/banks}"
all = [
"style",
"check",
"typing",
"lint",
]
fmt = "ruff format {args}"

[tool.hatch.build.targets.wheel]
only-include = ["src/banks", "src/templates"]
Expand All @@ -92,11 +90,6 @@ only-include = ["src/banks", "src/templates"]
"src" = ""
"templates" = "banks/templates"

[tool.black]
target-version = ["py39"]
line-length = 120
skip-string-normalization = true

[tool.ruff]
target-version = "py39"
line-length = 120
Expand Down Expand Up @@ -157,6 +150,7 @@ ban-relative-imports = "parents"
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]


[tool.coverage.run]
source_pkgs = ["banks", "tests"]
branch = true
Expand All @@ -181,4 +175,13 @@ module = [
"simplemma.*",
"litellm.*",
]
ignore_missing_imports = true
ignore_missing_imports = true

[tool.pylint]
disable = [
"line-too-long",
"too-few-public-methods",
"missing-module-docstring",
"missing-class-docstring",
"missing-function-docstring",
]
10 changes: 4 additions & 6 deletions src/banks/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,16 @@ def __getattribute__(self, name: str) -> Any:

# Env var takes precedence
prefix = super().__getattribute__("_env_var_prefix")
value = os.environ.get(f"{prefix}{name}")
if value is None:
read_value = os.environ.get(f"{prefix}{name}")
if read_value is None:
return original_value

# Convert string from env var to the actual type
t = super().__getattribute__("__annotations__")[name]
if t == bool:
value = strtobool(value)
else:
value = t(value)
return strtobool(read_value)

return value
return t(read_value)


config = _BanksConfig()
9 changes: 5 additions & 4 deletions src/banks/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@
from .filters import lemmatize


def _add_extensions(env):
def _add_extensions(_env):
"""
We lazily add extensions so that we can use the env in the extensions themselves if needed.
For example, we use banks to manage the system prompt in `GenerateExtension`
"""
from .extensions import GenerateExtension, HFInferenceEndpointsExtension
from .extensions.generate import GenerateExtension # pylint: disable=import-outside-toplevel
from .extensions.inference_endpoint import HFInferenceEndpointsExtension # pylint: disable=import-outside-toplevel

env.add_extension(GenerateExtension)
env.add_extension(HFInferenceEndpointsExtension)
_env.add_extension(GenerateExtension)
_env.add_extension(HFInferenceEndpointsExtension)


# Init the Jinja env
Expand Down
15 changes: 11 additions & 4 deletions src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
SYSTEM_PROMPT = Prompt("{{canary_word}} You are a helpful assistant.")


def generate(model_name: str): # noqa # This function exists for documentation purpose.
# This function exists for documentation purpose.
def generate(model_name: str): # noqa # pylint: disable=W0613
"""
`generate` can be used to call the LiteLLM API passing the tag text as a prompt and get back some content.
Expand Down Expand Up @@ -48,7 +49,9 @@ def parse(self, parser):
args.append(nodes.Const(None))

if parser.environment.is_async:
return nodes.Output([self.call_method("_agenerate", args)]).set_lineno(lineno)
return nodes.Output([self.call_method("_agenerate", args)]).set_lineno(
lineno
)
return nodes.Output([self.call_method("_generate", args)]).set_lineno(lineno)

def _generate(self, text, model_name=DEFAULT_MODEL):
Expand All @@ -61,7 +64,9 @@ def _generate(self, text, model_name=DEFAULT_MODEL):
{"role": "system", "content": SYSTEM_PROMPT.text()},
{"role": "user", "content": text},
]
response: ModelResponse = cast(ModelResponse, completion(model=model_name, messages=messages))
response: ModelResponse = cast(
ModelResponse, completion(model=model_name, messages=messages)
)
return self._get_content(response)

async def _agenerate(self, text, model_name=DEFAULT_MODEL):
Expand All @@ -74,7 +79,9 @@ async def _agenerate(self, text, model_name=DEFAULT_MODEL):
{"role": "system", "content": SYSTEM_PROMPT.text()},
{"role": "user", "content": text},
]
response: ModelResponse = cast(ModelResponse, await acompletion(model=model_name, messages=messages))
response: ModelResponse = cast(
ModelResponse, await acompletion(model=model_name, messages=messages)
)
return self._get_content(response)

def _get_content(self, response: ModelResponse) -> str:
Expand Down
6 changes: 3 additions & 3 deletions src/banks/filters/lemmatize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
try:
from simplemma import text_lemmatizer

simplemma_avail = True
SIMPLEMMA_AVAIL = True
except ImportError:
simplemma_avail = False
SIMPLEMMA_AVAIL = False


def lemmatize(text: str) -> str:
Expand All @@ -25,7 +25,7 @@ def lemmatize(text: str) -> str:
Note:
Simplemma must be manually installed to use this filter
"""
if not simplemma_avail:
if not SIMPLEMMA_AVAIL:
err_msg = "simplemma is not available, please install it with 'pip install simplemma'"
raise MissingDependencyError(err_msg)

Expand Down
2 changes: 1 addition & 1 deletion src/banks/registries/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _make_id(name: str, version: str | None):
return name

def save(self) -> None:
with open(self._index_fpath, "w") as f:
with open(self._index_fpath, "w", encoding="locale") as f:
f.write(self._index.model_dump_json())

def get(self, name: str, version: str | None = None) -> "Prompt":
Expand Down
9 changes: 5 additions & 4 deletions src/banks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ def strtobool(val: str) -> bool:
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
elif val in ("n", "no", "f", "false", "off", "0"):

if val in ("n", "no", "f", "false", "off", "0"):
return False
else:
msg = f"invalid truth value {val}"
raise ValueError(msg)

msg = f"invalid truth value {val}"
raise ValueError(msg)


def generate_canary_word(prefix: str = "BANKS[", suffix: str = "]", token_length: int = 8) -> str:
Expand Down

0 comments on commit bf6092e

Please sign in to comment.