From 0bcf850c54f5131bd888dd7bacdd8f269480097e Mon Sep 17 00:00:00 2001 From: Bibo Hao Date: Wed, 15 May 2024 21:30:21 +0800 Subject: [PATCH] Update APIs service (#17) * new feats * fix requests error * Update README.md --- README.md | 4 +- demo/app_common/ainlp/__init__.py | 0 demo/app_common/ainlp/model_bert.py | 86 +++++++++ demo/app_common/ainlp/test-gpu-async.py | 0 demo/app_common/api/api_multipart.py | 14 ++ demo/app_common/debug.py | 1 + src/aloha/config/paths.py | 8 +- src/aloha/encrypt/vault/cyberark.py | 3 +- src/aloha/service/api/v0.py | 6 +- src/aloha/service/http/base_api_handler.py | 14 +- src/aloha/service/http/files.py | 33 ++++ src/aloha/service/streamer/redis.py | 8 +- src/aloha/times/timeout_async.py | 210 +++++++++++++++++++-- 13 files changed, 357 insertions(+), 30 deletions(-) create mode 100644 demo/app_common/ainlp/__init__.py create mode 100644 demo/app_common/ainlp/model_bert.py create mode 100644 demo/app_common/ainlp/test-gpu-async.py create mode 100644 demo/app_common/api/api_multipart.py create mode 100644 src/aloha/service/http/files.py diff --git a/README.md b/README.md index 451f9a9..511972c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Aloha! [![License](https://img.shields.io/github/license/QPod/aloha)](https://github.com/QPod/aloha/blob/main/LICENSE) -[![GitHub Workflow Status](https://img.shields.io/github/workflow/status/QPod/aloha/build)](https://github.com/QPod/aloha/actions) +[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/QPod/aloha-python/pip.yml?branch=main)](https://github.com/QPod/aloha-python/actions) [![Join the Gitter Chat](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/QPod/) [![PyPI version](https://img.shields.io/pypi/v/aloha)](https://pypi.python.org/pypi/aloha/) [![PyPI Downloads](https://img.shields.io/pypi/dm/aloha)](https://pepy.tech/badge/aloha/) @@ -21,6 +21,6 @@ Please generously STAR★ our project or donate to us! [![GitHub Starts](https: ## Getting started -```py +```shell pip install aloha[all] ``` diff --git a/demo/app_common/ainlp/__init__.py b/demo/app_common/ainlp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/demo/app_common/ainlp/model_bert.py b/demo/app_common/ainlp/model_bert.py new file mode 100644 index 0000000..477a895 --- /dev/null +++ b/demo/app_common/ainlp/model_bert.py @@ -0,0 +1,86 @@ +from typing import List + +import torch +from transformers import AutoTokenizer, AutoModel + +from aloha.service.streamer import ManagedModel + +SEED = 0 +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) + + +class TextUnmaskModel: + def __init__(self, max_sent_len=16, model_path="bert-base-uncased"): + self.model_path = model_path + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) + self.transformer = AutoModel.from_pretrained(self.model_path) + self.transformer.eval() + self.transformer.to(device="cuda") + self.max_sent_len = max_sent_len + + def predict(self, batch: List[str]) -> List[str]: + """predict masked word""" + batch_inputs = [] + masked_indexes = [] + + for text in batch: + tokenized_text = self.tokenizer.tokenize(text) + if len(tokenized_text) > self.max_sent_len - 2: + tokenized_text = tokenized_text[: self.max_sent_len - 2] + + tokenized_text = ['[CLS]'] + tokenized_text + ['[SEP]'] + tokenized_text += ['[PAD]'] * (self.max_sent_len - len(tokenized_text)) + + indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text) + batch_inputs.append(indexed_tokens) + masked_indexes.append(tokenized_text.index('[MASK]')) + + tokens_tensor = torch.tensor(batch_inputs).to("cuda") + + with torch.no_grad(): + # prediction_scores: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` + prediction_scores = self.transformer(tokens_tensor)[0] + + batch_outputs = [] + for i in range(len(batch_inputs)): + predicted_index = torch.argmax(prediction_scores[i, masked_indexes[i]]).item() + predicted_token = self.tokenizer.convert_ids_to_tokens(predicted_index) + batch_outputs.append(predicted_token) + + return batch_outputs + + +class ManagedBertModel(ManagedModel): + def init_model(self): + self.model = TextUnmaskModel() + + def predict(self, batch): + return self.model.predict(batch) + + +def test_simple(): + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + model = AutoModel.from_pretrained("bert-base-uncased") + inputs = tokenizer("Hello! My name is [MASK]!", return_tensors="pt") + outputs = model(**inputs) + print(outputs) + + predicted_index = torch.argmax(outputs[1]).item() + predicted_token = tokenizer.convert_ids_to_tokens(predicted_index) + print(predicted_token) + + +def test_batch(): + batch_text = [ + "twinkle twinkle [MASK] star.", + "Happy birthday to [MASK].", + 'the answer to life, the [MASK], and everything.' + ] + model = TextUnmaskModel() + outputs = model.predict(batch_text) + print(outputs) + + +if __name__ == "__main__": + test_simple() diff --git a/demo/app_common/ainlp/test-gpu-async.py b/demo/app_common/ainlp/test-gpu-async.py new file mode 100644 index 0000000..e69de29 diff --git a/demo/app_common/api/api_multipart.py b/demo/app_common/api/api_multipart.py new file mode 100644 index 0000000..595bea1 --- /dev/null +++ b/demo/app_common/api/api_multipart.py @@ -0,0 +1,14 @@ +from aloha.logger import LOG +from aloha.service.api.v0 import APIHandler + + +class MultipartHandler(APIHandler): + def response(self, params=None, *args, **kwargs): + LOG.debug(params) + return params + + +default_handlers = [ + # internal API: QueryDB Postgres with sql directly + (r"/api_internal/multipart", MultipartHandler), +] diff --git a/demo/app_common/debug.py b/demo/app_common/debug.py index aef06ae..59f4b68 100644 --- a/demo/app_common/debug.py +++ b/demo/app_common/debug.py @@ -6,6 +6,7 @@ def main(): modules_to_load = [ "app_common.api.api_common_sys_info", "app_common.api.api_common_query_postgres", + "app_common.api.api_multipart", ] if 'service' not in SETTINGS.config: diff --git a/src/aloha/config/paths.py b/src/aloha/config/paths.py index 1e40c54..8d1cf6a 100644 --- a/src/aloha/config/paths.py +++ b/src/aloha/config/paths.py @@ -48,15 +48,19 @@ def get_config_files() -> list: files = files_config.split(',') ret = [] + msgs = [] for f in files: file = get_config_dir(f) if not os.path.exists(file): - warnings.warn('Expecting config file [%s] but it does not exists!' % file) + msgs.append('Expecting config file [%s] but it does not exists!' % file) else: print(' ---> Loading config file [%s]' % file) ret.append(os.path.expandvars(f)) if len(ret) == 0: - warnings.warn('No config files set properly, EMPTY config will be used!') + msgs.append('No config files set properly, EMPTY config will be used!') + + if len(msgs) > 0: + warnings.warn('\n'.join(msgs)) return ret diff --git a/src/aloha/encrypt/vault/cyberark.py b/src/aloha/encrypt/vault/cyberark.py index 33273de..5d3345d 100644 --- a/src/aloha/encrypt/vault/cyberark.py +++ b/src/aloha/encrypt/vault/cyberark.py @@ -11,7 +11,8 @@ from ...logger import LOG requests.packages.urllib3.disable_warnings(InsecureRequestWarning) -requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS += ':HIGHT:!DH:!aNULL' +if hasattr(requests.packages.urllib3.util.ssl_, 'DEFAULT_CIPHERS'): + requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS += ':HIGHT:!DH:!aNULL' class CyberArkVault(BaseVault, AesEncryptor): diff --git a/src/aloha/service/api/v0.py b/src/aloha/service/api/v0.py index b5ede1b..b239cfc 100644 --- a/src/aloha/service/api/v0.py +++ b/src/aloha/service/api/v0.py @@ -13,8 +13,10 @@ class APIHandler(AbstractApiHandler, ABC): } async def post(self, *args, **kwargs): - body_arguments = self.request_body - kwargs.update(body_arguments) + req_body = self.request_body + + if req_body is not None: # body_arguments + kwargs.update(req_body) resp = dict(code=5200, message=['success']) try: diff --git a/src/aloha/service/http/base_api_handler.py b/src/aloha/service/http/base_api_handler.py index 42ee7bf..0cbd406 100644 --- a/src/aloha/service/http/base_api_handler.py +++ b/src/aloha/service/http/base_api_handler.py @@ -51,7 +51,7 @@ def request_body(self) -> dict: body_arguments: dict = Optional[None] if content_type.startswith('multipart/form-data'): # only parse files when 'Content-Type' starts with 'multipart/form-data' - body_arguments = self.request.body_arguments + body_arguments = self.request_param # self.request.body_arguments else: try: body = self.request.body.decode('utf-8') @@ -62,8 +62,16 @@ def request_body(self) -> dict: @property def request_param(self) -> dict: - url_arguments: dict = {k: v[0].decode('utf-8') for k, v in self.request.arguments.items()} - return url_arguments + ret: dict = {} + for k, v in self.request.arguments.items(): + val = v[0].decode('utf-8') + try: + value = json.loads(val) + except json.JSONDecodeError: + value = val + ret[k] = value + + return ret class DefaultHandler404(AbstractApiHandler): diff --git a/src/aloha/service/http/files.py b/src/aloha/service/http/files.py new file mode 100644 index 0000000..79fe60c --- /dev/null +++ b/src/aloha/service/http/files.py @@ -0,0 +1,33 @@ +import time + +import requests + +from ...logger import LOG + + +def iter_over_request_files(request, url_files): + for file_key, files in request.files.items(): # iter over files uploaded by multipart + for f in files: + file_name, content_type = f["filename"], f["content_type"] + body = f.get('body', b"") + LOG.info(f"File {file_name} from multipart has content type {content_type} and length bytes={len(body)}") + yield file_key, file_name, content_type, body + + for file_key, list_url in {'url_files': url_files or []}.items(): # iter over files specified by `url_files` + for url in sorted(set(list_url)): + try: + t_start = time.time() + resp = requests.get(url, stream=True) # download the file from given url + if resp.status_code == 200: + body = resp.content + content_type = resp.headers.get("Content-Type", "UNKNOWN") + else: + raise RuntimeError("Failed to download file after %s seconds with code=%s from URL %s" % ( + time.time() - t_start, resp.status_code, url + )) + del resp + except Exception as e: + raise e + t_cost = time.time() - t_start + LOG.info(f"File {url} has content type {content_type} and length bytes={len(body)}, downloaded in {t_cost} seconds") + yield 'url_files', url, content_type, body diff --git a/src/aloha/service/streamer/redis.py b/src/aloha/service/streamer/redis.py index 7b97137..3f32632 100644 --- a/src/aloha/service/streamer/redis.py +++ b/src/aloha/service/streamer/redis.py @@ -5,9 +5,13 @@ import threading import time -from redis import Redis - from .base import BaseStreamer, BaseWorker, TIMEOUT, TIME_SLEEP, logger +from ...logger import LOG + +try: + from redis import Redis +except ImportError: + LOG.warn('redis not installed, service.streamer.RedisStreamer will no be available!') class RedisWorker(BaseWorker): diff --git a/src/aloha/times/timeout_async.py b/src/aloha/times/timeout_async.py index 194845d..8ca7f7b 100644 --- a/src/aloha/times/timeout_async.py +++ b/src/aloha/times/timeout_async.py @@ -1,25 +1,199 @@ +""" Refer to: https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py +""" import asyncio -from functools import wraps, partial +import enum +import warnings +from types import TracebackType +from typing import Optional, Type, final -__all__ = ('timeout',) +__all__ = ("timeout", "timeout_at", "Timeout") -def aioify(func): - @wraps(func) - async def run(*args, loop=None, executor=None, **kwargs): - if loop is None: - loop = asyncio.get_event_loop() - p_func = partial(func, *args, **kwargs) - return await loop.run_in_executor(executor, p_func) +def timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + if delay is not None: + deadline = loop.time() + delay # type: Optional[float] + else: + deadline = None + return Timeout(deadline, loop) - return run +def timeout_at(deadline: Optional[float]) -> "Timeout": + """Schedule the timeout at absolute time. + deadline argument points on the time in the same clock system + as loop.time(). + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + >>> async with timeout_at(loop.time() + 10): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + """ + loop = asyncio.get_running_loop() + return Timeout(deadline, loop) -def timeout(func, timeout=1): - async def f(): - try: - result = await asyncio.wait_for(aioify(func), timeout=timeout) - return result - except asyncio.TimeoutError: - raise TimeoutError() - return f + +class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + +@final +class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. + # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. + + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + + def __init__( + self, deadline: Optional[float], loop: asyncio.AbstractEventLoop + ) -> None: + self._loop = loop + self._state = _State.INIT + + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + def __enter__(self) -> "Timeout": + warnings.warn( + "with timeout() is deprecated, use async with timeout() instead", + DeprecationWarning, + stacklevel=2, + ) + self._do_enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. + if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + The delay can be negative. + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError("cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + deadline argument points on the time in the same clock system + as loop.time(). + If new deadline is in the past the timeout is raised immediately. + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + """ + if self._state == _State.EXIT: + raise RuntimeError("cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon(self._on_timeout, task) + else: + self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + self._timeout_handler = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self, task: "asyncio.Task[None]") -> None: + task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None