From fe7c68b3ea5dcff53668ea29d4194cca407c8e63 Mon Sep 17 00:00:00 2001
From: jstzwj <1103870790@qq.com>
Date: Thu, 15 Aug 2024 16:47:01 +0800
Subject: [PATCH] code format

---
 olah/database/__init__.py |  0
 olah/database/models.py   | 26 +++++++++++
 olah/proxy/lfs.py         |  2 +-
 olah/proxy/meta.py        |  9 +++-
 olah/utils/logging.py     | 96 ---------------------------------------
 olah/utils/olah_utils.py  | 17 +++++++
 requirements.txt          |  1 +
 7 files changed, 52 insertions(+), 99 deletions(-)
 create mode 100644 olah/database/__init__.py
 create mode 100644 olah/database/models.py
 create mode 100644 olah/utils/olah_utils.py

diff --git a/olah/database/__init__.py b/olah/database/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/olah/database/models.py b/olah/database/models.py
new file mode 100644
index 0000000..1d4d9fc
--- /dev/null
+++ b/olah/database/models.py
@@ -0,0 +1,26 @@
+# coding=utf-8
+# Copyright 2024 XiaHan
+#
+# Use of this source code is governed by an MIT-style
+# license that can be found in the LICENSE file or at
+# https://opensource.org/licenses/MIT.
+
+import os
+from peewee import *
+import datetime
+
+from olah.utils.olah_utils import get_olah_path
+
+
+
+db_path = os.path.join(get_olah_path(), "database.db")
+db = SqliteDatabase(db_path)
+
+
+class BaseModel(Model):
+    class Meta:
+        database = db
+
+
+class User(BaseModel):
+    username = CharField(unique=True)
diff --git a/olah/proxy/lfs.py b/olah/proxy/lfs.py
index a85dd26..69938b6 100644
--- a/olah/proxy/lfs.py
+++ b/olah/proxy/lfs.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Copyright 2024 XiaHan
-# 
+#
 # Use of this source code is governed by an MIT-style
 # license that can be found in the LICENSE file or at
 # https://opensource.org/licenses/MIT.
diff --git a/olah/proxy/meta.py b/olah/proxy/meta.py
index 1fac6ab..9d210d0 100644
--- a/olah/proxy/meta.py
+++ b/olah/proxy/meta.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Copyright 2024 XiaHan
-# 
+#
 # Use of this source code is governed by an MIT-style
 # license that can be found in the LICENSE file or at
 # https://opensource.org/licenses/MIT.
@@ -29,6 +29,7 @@ async def meta_cache_generator(app: FastAPI, save_path: str):
                 break
             yield chunk
 
+
 async def meta_proxy_cache(
     app: FastAPI,
     repo_type: Literal["models", "datasets", "spaces"],
@@ -66,7 +67,10 @@ async def meta_proxy_cache(
             with open(save_path, "wb") as meta_file:
                 meta_file.write(response.content)
         else:
-            raise Exception(f"Cannot get the branch info from the url {meta_url}, status: {response.status_code}")
+            raise Exception(
+                f"Cannot get the branch info from the url {meta_url}, status: {response.status_code}"
+            )
+
 
 async def meta_proxy_generator(
     app: FastAPI,
@@ -98,6 +102,7 @@ async def meta_proxy_generator(
         with open(save_path, "wb") as f:
            f.write(bytes(content))
 
+
 async def meta_generator(
     app: FastAPI,
     repo_type: Literal["models", "datasets", "spaces"],
diff --git a/olah/utils/logging.py b/olah/utils/logging.py
index 8e9c16d..dac6bc0 100644
--- a/olah/utils/logging.py
+++ b/olah/utils/logging.py
@@ -21,13 +21,6 @@
 
 from olah.constants import DEFAULT_LOGGER_DIR
 
-server_error_msg = (
-    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
-)
-moderation_msg = (
-    "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
-)
-
 handler = None
 
 
@@ -143,95 +136,6 @@ def flush(self):
         self.linebuf = ""
 
 
-def disable_torch_init():
-    """
-    Disable the redundant torch default initialization to accelerate model creation.
-    """
-    import torch
-
-    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
-    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
-
-
-def get_gpu_memory(max_gpus=None):
-    """Get available memory for each GPU."""
-    gpu_memory = []
-    num_gpus = (
-        torch.cuda.device_count()
-        if max_gpus is None
-        else min(max_gpus, torch.cuda.device_count())
-    )
-
-    for gpu_id in range(num_gpus):
-        with torch.cuda.device(gpu_id):
-            device = torch.cuda.current_device()
-            gpu_properties = torch.cuda.get_device_properties(device)
-            total_memory = gpu_properties.total_memory / (1024**3)
-            allocated_memory = torch.cuda.memory_allocated() / (1024**3)
-            available_memory = total_memory - allocated_memory
-            gpu_memory.append(available_memory)
-    return gpu_memory
-
-
-def violates_moderation(text):
-    """
-    Check whether the text violates OpenAI moderation API.
-    """
-    url = "https://api.openai.com/v1/moderations"
-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
-    }
-    text = text.replace("\n", "")
-    data = "{" + '"input": ' + f'"{text}"' + "}"
-    data = data.encode("utf-8")
-    try:
-        ret = requests.post(url, headers=headers, data=data, timeout=5)
-        flagged = ret.json()["results"][0]["flagged"]
-    except requests.exceptions.RequestException as e:
-        flagged = False
-    except KeyError as e:
-        flagged = False
-
-    return flagged
-
-
-# Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings,
-# Use this function to make sure it can be correctly loaded.
-def clean_flant5_ckpt(ckpt_path):
-    index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
-    index_json = json.load(open(index_file, "r"))
-
-    weightmap = index_json["weight_map"]
-
-    share_weight_file = weightmap["shared.weight"]
-    share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[
-        "shared.weight"
-    ]
-
-    for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]:
-        weight_file = weightmap[weight_name]
-        weight = torch.load(os.path.join(ckpt_path, weight_file))
-        weight[weight_name] = share_weight
-        torch.save(weight, os.path.join(ckpt_path, weight_file))
-
-
-def pretty_print_semaphore(semaphore):
-    if semaphore is None:
-        return "None"
-    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
-
-
-get_window_url_params_js = """
-function() {
-    const params = new URLSearchParams(window.location.search);
-    url_params = Object.fromEntries(params);
-    console.log("url_params", url_params);
-    return url_params;
-    }
-"""
-
-
 def iter_over_async(
     async_gen: AsyncGenerator, event_loop: AbstractEventLoop
 ) -> Generator:
diff --git a/olah/utils/olah_utils.py b/olah/utils/olah_utils.py
new file mode 100644
index 0000000..dd686e4
--- /dev/null
+++ b/olah/utils/olah_utils.py
@@ -0,0 +1,17 @@
+# coding=utf-8
+# Copyright 2024 XiaHan
+#
+# Use of this source code is governed by an MIT-style
+# license that can be found in the LICENSE file or at
+# https://opensource.org/licenses/MIT.
+
+import platform
+import os
+
+
+def get_olah_path() -> str:
+    if platform.system() == "Windows":
+        olah_path = os.path.expanduser("~\\.olah")
+    else:
+        olah_path = os.path.expanduser("~/.olah")
+    return olah_path
diff --git a/requirements.txt b/requirements.txt
index b748e2e..f1ec968 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,4 @@ pytest==8.2.2
 cachetools==5.4.0
 PyYAML==6.0.1
 tenacity==8.5.0
+peewee==3.17.6