diff --git a/pyproject.toml b/pyproject.toml index 5341974..832b958 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,7 +157,6 @@ omit = [ "tests/*", "src/banks/extensions/docs.py", # deprecated modules, to be removed - "src/banks/extensions/inference_endpoint.py", "src/banks/extensions/generate.py", ] diff --git a/src/banks/env.py b/src/banks/env.py index 7d495a8..f383f62 100644 --- a/src/banks/env.py +++ b/src/banks/env.py @@ -16,12 +16,10 @@ def _add_extensions(_env): from .extensions.chat import ChatExtension # pylint: disable=import-outside-toplevel from .extensions.completion import CompletionExtension # pylint: disable=import-outside-toplevel from .extensions.generate import GenerateExtension # pylint: disable=import-outside-toplevel - from .extensions.inference_endpoint import HFInferenceEndpointsExtension # pylint: disable=import-outside-toplevel _env.add_extension(ChatExtension) _env.add_extension(CompletionExtension) _env.add_extension(GenerateExtension) - _env.add_extension(HFInferenceEndpointsExtension) # Init the Jinja env diff --git a/src/banks/extensions/__init__.py b/src/banks/extensions/__init__.py index 6046ecd..2f5bf74 100644 --- a/src/banks/extensions/__init__.py +++ b/src/banks/extensions/__init__.py @@ -2,6 +2,5 @@ # # SPDX-License-Identifier: MIT from banks.extensions.generate import GenerateExtension -from banks.extensions.inference_endpoint import HFInferenceEndpointsExtension -__all__ = ("GenerateExtension", "HFInferenceEndpointsExtension") +__all__ = ("GenerateExtension",) diff --git a/src/banks/extensions/inference_endpoint.py b/src/banks/extensions/inference_endpoint.py deleted file mode 100644 index 4eb80e8..0000000 --- a/src/banks/extensions/inference_endpoint.py +++ /dev/null @@ -1,58 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi -# -# SPDX-License-Identifier: MIT -import html -import os - -import requests -from deprecated import deprecated -from jinja2 import nodes -from jinja2.ext import Extension - - -@deprecated(version="1.3.0", reason="This extension is deprecated, use {% completion %} instead.") -class HFInferenceEndpointsExtension(Extension): - """ - `inference_endpoint` can be used to call the Hugging Face Inference Endpoint API - passing a prompt to get back some content. - - Deprecated: - This extension is deprecated, use `{% completion %}` instead. - - Example: - ```jinja - {% inference_endpoint "write a tweet with positive sentiment", "https://foo.aws.endpoints.huggingface.cloud" %} - Life is beautiful, full of opportunities & positivity - ``` - """ - - # a set of names that trigger the extension. - tags = {"inference_endpoint"} # noqa - - def parse(self, parser): - # We get the line number of the first token so that we can give - # that line number to the nodes we create by hand. - lineno = next(parser.stream).lineno - - # The args passed to the extension: - # - the prompt text used to generate new text - args = [parser.parse_expression()] - # - second param after the comma, the inference endpoint URL - parser.stream.skip_if("comma") - args.append(parser.parse_expression()) - - return nodes.Output([self.call_method("_call_endpoint", args)]).set_lineno(lineno) - - def _call_endpoint(self, text, endpoint): - """ - Helper callback. - """ - access_token = os.environ.get("HF_ACCESS_TOKEN") - response = requests.post( - endpoint, json={"inputs": text}, headers={"Authorization": f"Bearer {access_token}"}, timeout=30 - ) - response_body = response.json() - - if response_body: - return html.unescape(response_body[0].get("generated_text", "")) - return ""