diff --git a/pyproject.toml b/pyproject.toml index ba86bc2..586c36a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,26 +64,24 @@ python = ["3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = false # Normally the linting env can be detached, but mypy doesn't install all the stubs we need dependencies = [ - "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", + "pylint", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/banks}" -style = [ +check = [ + "ruff format --check {args}", "ruff check {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff check --fix {args:.}", - "style", ] +lint = "pylint {args:src/banks}" +typing = "mypy --install-types --non-interactive {args:src/banks}" all = [ - "style", + "check", "typing", + "lint", ] +fmt = "ruff format {args}" [tool.hatch.build.targets.wheel] only-include = ["src/banks", "src/templates"] @@ -92,11 +90,6 @@ only-include = ["src/banks", "src/templates"] "src" = "" "templates" = "banks/templates" -[tool.black] -target-version = ["py39"] -line-length = 120 -skip-string-normalization = true - [tool.ruff] target-version = "py39" line-length = 120 @@ -157,6 +150,7 @@ ban-relative-imports = "parents" # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] + [tool.coverage.run] source_pkgs = ["banks", "tests"] branch = true @@ -181,4 +175,13 @@ module = [ "simplemma.*", "litellm.*", ] -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true + +[tool.pylint] +disable = [ + "line-too-long", + "too-few-public-methods", + "missing-module-docstring", + "missing-class-docstring", + "missing-function-docstring", +] \ No newline at end of file diff --git a/src/banks/config.py b/src/banks/config.py index 044e199..c7594dc 100644 --- a/src/banks/config.py +++ b/src/banks/config.py @@ -23,18 +23,16 @@ def __getattribute__(self, name: str) -> Any: # Env var takes precedence prefix = super().__getattribute__("_env_var_prefix") - value = os.environ.get(f"{prefix}{name}") - if value is None: + read_value = os.environ.get(f"{prefix}{name}") + if read_value is None: return original_value # Convert string from env var to the actual type t = super().__getattribute__("__annotations__")[name] if t == bool: - value = strtobool(value) - else: - value = t(value) + return strtobool(read_value) - return value + return t(read_value) config = _BanksConfig() diff --git a/src/banks/env.py b/src/banks/env.py index fd06c27..0112440 100644 --- a/src/banks/env.py +++ b/src/banks/env.py @@ -7,16 +7,17 @@ from .filters import lemmatize -def _add_extensions(env): +def _add_extensions(_env): """ We lazily add extensions so that we can use the env in the extensions themselves if needed. For example, we use banks to manage the system prompt in `GenerateExtension` """ - from .extensions import GenerateExtension, HFInferenceEndpointsExtension + from .extensions.generate import GenerateExtension # pylint: disable=import-outside-toplevel + from .extensions.inference_endpoint import HFInferenceEndpointsExtension # pylint: disable=import-outside-toplevel - env.add_extension(GenerateExtension) - env.add_extension(HFInferenceEndpointsExtension) + _env.add_extension(GenerateExtension) + _env.add_extension(HFInferenceEndpointsExtension) # Init the Jinja env diff --git a/src/banks/extensions/generate.py b/src/banks/extensions/generate.py index 6c44b52..9cfb847 100644 --- a/src/banks/extensions/generate.py +++ b/src/banks/extensions/generate.py @@ -14,7 +14,8 @@ SYSTEM_PROMPT = Prompt("{{canary_word}} You are a helpful assistant.") -def generate(model_name: str): # noqa # This function exists for documentation purpose. +# This function exists for documentation purpose. +def generate(model_name: str): # noqa # pylint: disable=W0613 """ `generate` can be used to call the LiteLLM API passing the tag text as a prompt and get back some content. @@ -48,7 +49,9 @@ def parse(self, parser): args.append(nodes.Const(None)) if parser.environment.is_async: - return nodes.Output([self.call_method("_agenerate", args)]).set_lineno(lineno) + return nodes.Output([self.call_method("_agenerate", args)]).set_lineno( + lineno + ) return nodes.Output([self.call_method("_generate", args)]).set_lineno(lineno) def _generate(self, text, model_name=DEFAULT_MODEL): @@ -61,7 +64,9 @@ def _generate(self, text, model_name=DEFAULT_MODEL): {"role": "system", "content": SYSTEM_PROMPT.text()}, {"role": "user", "content": text}, ] - response: ModelResponse = cast(ModelResponse, completion(model=model_name, messages=messages)) + response: ModelResponse = cast( + ModelResponse, completion(model=model_name, messages=messages) + ) return self._get_content(response) async def _agenerate(self, text, model_name=DEFAULT_MODEL): @@ -74,7 +79,9 @@ async def _agenerate(self, text, model_name=DEFAULT_MODEL): {"role": "system", "content": SYSTEM_PROMPT.text()}, {"role": "user", "content": text}, ] - response: ModelResponse = cast(ModelResponse, await acompletion(model=model_name, messages=messages)) + response: ModelResponse = cast( + ModelResponse, await acompletion(model=model_name, messages=messages) + ) return self._get_content(response) def _get_content(self, response: ModelResponse) -> str: diff --git a/src/banks/filters/lemmatize.py b/src/banks/filters/lemmatize.py index 707246d..277c210 100644 --- a/src/banks/filters/lemmatize.py +++ b/src/banks/filters/lemmatize.py @@ -6,9 +6,9 @@ try: from simplemma import text_lemmatizer - simplemma_avail = True + SIMPLEMMA_AVAIL = True except ImportError: - simplemma_avail = False + SIMPLEMMA_AVAIL = False def lemmatize(text: str) -> str: @@ -25,7 +25,7 @@ def lemmatize(text: str) -> str: Note: Simplemma must be manually installed to use this filter """ - if not simplemma_avail: + if not SIMPLEMMA_AVAIL: err_msg = "simplemma is not available, please install it with 'pip install simplemma'" raise MissingDependencyError(err_msg) diff --git a/src/banks/registries/file.py b/src/banks/registries/file.py index 7d7c614..7fac0f2 100644 --- a/src/banks/registries/file.py +++ b/src/banks/registries/file.py @@ -40,7 +40,7 @@ def _make_id(name: str, version: str | None): return name def save(self) -> None: - with open(self._index_fpath, "w") as f: + with open(self._index_fpath, "w", encoding="locale") as f: f.write(self._index.model_dump_json()) def get(self, name: str, version: str | None = None) -> "Prompt": diff --git a/src/banks/utils.py b/src/banks/utils.py index 48f17ef..eb29581 100644 --- a/src/banks/utils.py +++ b/src/banks/utils.py @@ -12,11 +12,12 @@ def strtobool(val: str) -> bool: val = val.lower() if val in ("y", "yes", "t", "true", "on", "1"): return True - elif val in ("n", "no", "f", "false", "off", "0"): + + if val in ("n", "no", "f", "false", "off", "0"): return False - else: - msg = f"invalid truth value {val}" - raise ValueError(msg) + + msg = f"invalid truth value {val}" + raise ValueError(msg) def generate_canary_word(prefix: str = "BANKS[", suffix: str = "]", token_length: int = 8) -> str: