Skip to content

Commit

Permalink
Implement cache size limit
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Sep 3, 2024
1 parent 0e15ec0 commit 2365951
Show file tree
Hide file tree
Showing 13 changed files with 415 additions and 66 deletions.
67 changes: 67 additions & 0 deletions docs/en/quickstart.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@

## Quick Start
Run the command in the console:
```bash
python -m olah.server
```

Then set the environment variable `HF_ENDPOINT` to the mirror site (here, http://localhost:8090).

Linux:
```bash
export HF_ENDPOINT=http://localhost:8090
```

Windows Powershell:
```bash
$env:HF_ENDPOINT = "http://localhost:8090"
```

From now on, all download operations in the HuggingFace library will be proxied through this mirror site.
```bash
pip install -U huggingface_hub
```

```python
from huggingface_hub import snapshot_download

snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model',
local_dir='./model_dir', resume_download=True,
max_workers=8)
```

Or you can download models and datasets using the huggingface-cli.

Download GPT2:
```bash
huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2
```

Download WikiText:
```bash
huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext
```

You can check the path `./repos`, in which olah stores all cached datasets and models.

## Start the server
Run the command in the console:
```bash
python -m olah.server
```

Or you can specify the host address and listening port:
```bash
python -m olah.server --host localhost --port 8090
```
**Note: Please change --mirror-netloc and --mirror-lfs-netloc to the actual URLs of the mirror sites when modifying the host and port.**
```bash
python -m olah.server --host 192.168.1.100 --port 8090 --mirror-netloc 192.168.1.100:8090
```

The default mirror cache path is `./repos`, you can change it by `--repos-path` parameter:
```bash
python -m olah.server --host localhost --port 8090 --repos-path ./hf_mirrors
```

**Note that the cached data between different versions cannot be migrated. Please delete the cache folder before upgrading to the latest version of Olah.**
45 changes: 45 additions & 0 deletions docs/zh/quickstart.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
## 快速开始
在控制台运行以下命令:
```bash
python -m olah.server
```

然后将环境变量`HF_ENDPOINT`设置为镜像站点(这里是http://localhost:8090/)。

Linux:
```bash
export HF_ENDPOINT=http://localhost:8090
```

Windows Powershell:
```bash
$env:HF_ENDPOINT = "http://localhost:8090"
```

从现在开始,HuggingFace库中的所有下载操作都将通过此镜像站点代理进行。
```bash
pip install -U huggingface_hub
```

```python
from huggingface_hub import snapshot_download

snapshot_download(repo_id='Qwen/Qwen-7B', repo_type='model',
local_dir='./model_dir', resume_download=True,
max_workers=8)

```

或者你也可以使用huggingface-cli直接下载模型和数据集。

下载GPT2:
```bash
huggingface-cli download --resume-download openai-community/gpt2 --local-dir gpt2
```

下载WikiText:
```bash
huggingface-cli download --repo-type dataset --resume-download Salesforce/wikitext --local-dir wikitext
```

您可以查看路径`./repos`,其中存储了所有数据集和模型的缓存。
8 changes: 7 additions & 1 deletion olah/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

from typing import List, Optional, Union
from typing import List, Literal, Optional, Union
import toml
import re
import fnmatch

from olah.utils.disk_utils import convert_to_bytes

DEFAULT_PROXY_RULES = [
{"repo": "*", "allow": True, "use_re": False},
{"repo": "*/*", "allow": True, "use_re": False},
Expand Down Expand Up @@ -83,6 +85,8 @@ def __init__(self, path: Optional[str] = None) -> None:
self.ssl_key = None
self.ssl_cert = None
self.repos_path = "./repos"
self.cache_size_limit: Optional[int] = None
self.cache_clean_strategy: Literal["LRU", "FIFO", "LARGE_FIRST"] = "LRU"

self.hf_scheme: str = "https"
self.hf_netloc: str = "huggingface.co"
Expand Down Expand Up @@ -140,6 +144,8 @@ def read_toml(self, path: str) -> None:
self.ssl_key = self.empty_str(basic.get("ssl-key", self.ssl_key))
self.ssl_cert = self.empty_str(basic.get("ssl-cert", self.ssl_cert))
self.repos_path = basic.get("repos-path", self.repos_path)
self.cache_size_limit = convert_to_bytes(basic.get("cache-size-limit", self.cache_size_limit))
self.cache_clean_strategy = basic.get("cache-clean-strategy", self.cache_clean_strategy)

self.hf_scheme = basic.get("hf-scheme", self.hf_scheme)
self.hf_netloc = basic.get("hf-netloc", self.hf_netloc)
Expand Down
30 changes: 22 additions & 8 deletions olah/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,39 @@

from olah.utils.olah_utils import get_olah_path



# Single SQLite database file stored under the Olah data directory.
db_path = os.path.join(get_olah_path(), "database.db")
db = SqliteDatabase(db_path)


class BaseModel(Model):
    """Base class for all Olah ORM models; binds them to the shared SQLite database."""

    class Meta:
        # peewee convention: Meta.database selects the backing connection.
        database = db


class User(BaseModel):
    """A user record, identified by a unique username."""

    username = CharField(unique=True)  # unique user identifier

class Token(BaseModel):
    """An access-token record with two timestamps."""

    token = CharField(unique=True)  # the token string itself
    first_dt = DateTimeField()      # presumably when the token was first seen — TODO confirm
    last_dt = DateTimeField()       # presumably when the token was last seen — TODO confirm

class DownloadLogs(BaseModel):
    """One record per downloaded byte range of a repository file."""

    id = CharField(unique=True)      # unique log-entry id (NOTE(review): shadows peewee's implicit `id` pk — confirm intended)
    org = CharField()                # repository organization/namespace
    repo = CharField()               # repository name
    path = CharField()               # file path within the repository
    range_start = BigIntegerField()  # start byte offset of the downloaded range
    range_end = BigIntegerField()    # end byte offset — inclusive or exclusive not shown here; verify against writer
    datetime = DateTimeField()       # when the download occurred
    user = CharField()               # requesting user
    token = CharField()              # token used for the request

class FileLevelLRU(BaseModel):
    """Per-file access record; presumably backs the "LRU" cache-clean strategy — TODO confirm against eviction code."""

    org = CharField()   # repository organization/namespace
    repo = CharField()  # repository name
    path = CharField()  # file path within the repository
    # Defaults to the row-creation time; `datetime.datetime.now` is passed as a
    # callable so it is evaluated per row, not once at import.
    datetime = DateTimeField(default=datetime.datetime.now)

# Open the connection and create any missing tables at import time.
# NOTE(review): `UserToken` is listed below but no `UserToken` model is defined
# in this file — confirm it is defined elsewhere; otherwise this raises
# NameError on import.
db.connect()
db.create_tables([
    User,
    Token,
    UserToken,
    DownloadLogs,
    FileLevelLRU,
])
91 changes: 50 additions & 41 deletions olah/proxy/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from olah.cache.olah_cache import OlahCache
from olah.proxy.pathsinfo import pathsinfo_generator
from olah.utils.cache_utils import _read_cache_request, _write_cache_request
from olah.utils.disk_utils import touch_file_access_time
from olah.utils.url_utils import (
RemoteInfo,
add_query_param,
Expand Down Expand Up @@ -244,6 +245,7 @@ async def _get_file_range_from_remote(
url=remote_info.url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
follow_redirects=True,
) as response:
async for raw_chunk in response.aiter_raw():
if not raw_chunk:
Expand Down Expand Up @@ -307,6 +309,7 @@ async def _file_chunk_get(
app,
save_path: str,
head_path: str,
client: httpx.AsyncClient,
method: str,
url: str,
headers: Dict[str, str],
Expand All @@ -319,6 +322,10 @@ async def _file_chunk_get(
else:
cache_file = OlahCache.create(save_path)
cache_file.resize(file_size=file_size)

# Refresh access time
touch_file_access_time(save_path)

try:
start_pos, end_pos = parse_range_params(
headers.get("range", f"bytes={0}-{file_size-1}"), file_size
Expand All @@ -328,7 +335,6 @@ async def _file_chunk_get(
ranges_and_cache_list = get_contiguous_ranges(cache_file, start_pos, end_pos)
# Stream ranges
for (range_start_pos, range_end_pos), is_remote in ranges_and_cache_list:
client = httpx.AsyncClient()
if is_remote:
generator = _get_file_range_from_remote(
client,
Expand Down Expand Up @@ -394,7 +400,6 @@ async def _file_chunk_get(
raise Exception(
f"The size of cached range ({range_end_pos - range_start_pos}) is different from sent size ({cur_pos - range_start_pos})."
)
await client.aclose()
finally:
cache_file.close()

Expand All @@ -403,24 +408,24 @@ async def _file_chunk_head(
app,
save_path: str,
head_path: str,
client: httpx.AsyncClient,
method: str,
url: str,
headers: Dict[str, str],
allow_cache: bool,
file_size: int,
):
if not app.app_settings.config.offline:
async with httpx.AsyncClient() as client:
async with client.stream(
method=method,
url=url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
) as response:
async for raw_chunk in response.aiter_raw():
if not raw_chunk:
continue
yield raw_chunk
async with client.stream(
method=method,
url=url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
) as response:
async for raw_chunk in response.aiter_raw():
if not raw_chunk:
continue
yield raw_chunk
else:
yield b""

Expand Down Expand Up @@ -518,32 +523,36 @@ async def _file_realtime_stream(
response_headers["etag"] = f'"{content_hash[:32]}-10"'
yield 200
yield response_headers
if method.lower() == "get":
async for each_chunk in _file_chunk_get(
app=app,
save_path=save_path,
head_path=head_path,
method=method,
url=hf_url,
headers=request_headers,
allow_cache=allow_cache,
file_size=file_size,
):
yield each_chunk
elif method.lower() == "head":
async for each_chunk in _file_chunk_head(
app=app,
save_path=save_path,
head_path=head_path,
method=method,
url=hf_url,
headers=request_headers,
allow_cache=allow_cache,
file_size=0,
):
yield each_chunk
else:
raise Exception(f"Unsupported method: {method}")

async with httpx.AsyncClient() as client:
if method.lower() == "get":
async for each_chunk in _file_chunk_get(
app=app,
save_path=save_path,
head_path=head_path,
client=client,
method=method,
url=hf_url,
headers=request_headers,
allow_cache=allow_cache,
file_size=file_size,
):
yield each_chunk
elif method.lower() == "head":
async for each_chunk in _file_chunk_head(
app=app,
save_path=save_path,
head_path=head_path,
client=client,
method=method,
url=hf_url,
headers=request_headers,
allow_cache=allow_cache,
file_size=0,
):
yield each_chunk
else:
raise Exception(f"Unsupported method: {method}")


async def file_get_generator(
Expand All @@ -558,7 +567,7 @@ async def file_get_generator(
):
org_repo = get_org_repo(org, repo)
# save
repos_path = app.app_settings.repos_path
repos_path = app.app_settings.config.repos_path
head_path = os.path.join(
repos_path, f"heads/{repo_type}/{org}/{repo}/resolve/{commit}/{file_path}"
)
Expand Down Expand Up @@ -612,7 +621,7 @@ async def cdn_file_get_generator(

org_repo = get_org_repo(org, repo)
# save
repos_path = app.app_settings.repos_path
repos_path = app.app_settings.config.repos_path
head_path = os.path.join(
repos_path, f"heads/{repo_type}/{org}/{repo}/cdn/{file_hash}"
)
Expand Down
4 changes: 2 additions & 2 deletions olah/proxy/lfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def lfs_head_generator(
app, dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request
):
# save
repos_path = app.app_settings.repos_path
repos_path = app.app_settings.config.repos_path
head_path = os.path.join(
repos_path, f"lfs/heads/{dir1}/{dir2}/{hash_repo}/{hash_file}"
)
Expand Down Expand Up @@ -47,7 +47,7 @@ async def lfs_get_generator(
app, dir1: str, dir2: str, hash_repo: str, hash_file: str, request: Request
):
# save
repos_path = app.app_settings.repos_path
repos_path = app.app_settings.config.repos_path
head_path = os.path.join(
repos_path, f"lfs/heads/{dir1}/{dir2}/{hash_repo}/{hash_file}"
)
Expand Down
2 changes: 1 addition & 1 deletion olah/proxy/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def meta_generator(

# save
method = request.method.lower()
repos_path = app.app_settings.repos_path
repos_path = app.app_settings.config.repos_path
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}"
)
Expand Down
2 changes: 1 addition & 1 deletion olah/proxy/pathsinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def pathsinfo_generator(
):
headers = {}
# save
repos_path = app.app_settings.repos_path
repos_path = app.app_settings.config.repos_path

final_content = []
for path in paths:
Expand Down
Loading

0 comments on commit 2365951

Please sign in to comment.