From 0e15ec0e19067131906963d414bdb5037657b1ad Mon Sep 17 00:00:00 2001
From: jstzwj <1103870790@qq.com>
Date: Tue, 3 Sep 2024 15:10:59 +0800
Subject: [PATCH] content-encoding bug fix

---
 olah/proxy/files.py     | 251 +++++++++++++++++++++++-----------------
 olah/proxy/meta.py      |   2 +-
 olah/proxy/pathsinfo.py |  10 +-
 olah/server.py          |   4 +-
 4 files changed, 154 insertions(+), 113 deletions(-)

diff --git a/olah/proxy/files.py b/olah/proxy/files.py
index 2060325..a5425f9 100644
--- a/olah/proxy/files.py
+++ b/olah/proxy/files.py
@@ -13,6 +13,7 @@
 from requests.structures import CaseInsensitiveDict
 import httpx
+import zlib
 from starlette.datastructures import URL
 from urllib.parse import urlparse, urljoin
@@ -25,6 +26,7 @@
     ORIGINAL_LOC,
 )
 from olah.cache.olah_cache import OlahCache
+from olah.proxy.pathsinfo import pathsinfo_generator
 from olah.utils.cache_utils import _read_cache_request, _write_cache_request
 from olah.utils.url_utils import (
     RemoteInfo,
@@ -233,22 +235,64 @@ async def _get_file_range_from_remote(
     start_pos: int,
     end_pos: int,
 ):
-    remote_info.headers["range"] = f"bytes={start_pos}-{end_pos - 1}"
+    headers = {}
+    headers["range"] = f"bytes={start_pos}-{end_pos - 1}"
     chunk_bytes = 0
+    raw_data = b""
     async with client.stream(
         method=remote_info.method,
         url=remote_info.url,
-        headers=remote_info.headers,
+        headers=headers,
         timeout=WORKER_API_TIMEOUT,
-    ) as response:
+    ) as response:
         async for raw_chunk in response.aiter_raw():
             if not raw_chunk:
                 continue
-            yield raw_chunk
+            if "content-encoding" in response.headers:
+                raw_data += raw_chunk
+            else:
+                yield raw_chunk
             chunk_bytes += len(raw_chunk)
+        # If the result is compressed, decode it before yielding
+        if "content-encoding" in response.headers:
+            final_data = raw_data
+            algorithms = response.headers["content-encoding"].split(',')
+            for algo in algorithms:
+                algo = algo.strip().lower()
+                if algo == "gzip":
+                    try:
+                        final_data = zlib.decompress(final_data, zlib.MAX_WBITS | 16)  # decompress gzip
+                    except Exception as e:
+                        print(f"Error decompressing gzip data: {e}")
+                elif algo == "compress":
+                    print(f"Unsupported decompression algorithm: {algo}")
+                elif algo == "deflate":
+                    try:
+                        final_data = zlib.decompress(final_data)
+                    except Exception as e:
+                        print(f"Error decompressing deflate data: {e}")
+                elif algo == "br":
+                    try:
+                        import brotli
+                        final_data = brotli.decompress(final_data)
+                    except Exception as e:
+                        print(f"Error decompressing Brotli data: {e}")
+                elif algo == "zstd":
+                    try:
+                        import zstandard
+                        final_data = zstandard.ZstdDecompressor().decompress(final_data)
+                    except Exception as e:
+                        print(f"Error decompressing Zstandard data: {e}")
+                else:
+                    print(f"Unsupported compression algorithm: {algo}")
+            chunk_bytes = len(final_data)
+            yield final_data
         if "content-length" in response.headers:
-            response_content_length = int(response.headers["content-length"])
+            if "content-encoding" in response.headers:
+                response_content_length = len(final_data)
+            else:
+                response_content_length = int(response.headers["content-length"])
             if end_pos - start_pos != response_content_length:
                 raise Exception(
                     f"The content of the response is incomplete. Expected-{end_pos - start_pos}. Accepted-{response_content_length}"
                 )
@@ -263,7 +307,6 @@ async def _file_chunk_get(
     app,
     save_path: str,
     head_path: str,
-    client: httpx.AsyncClient,
     method: str,
     url: str,
     headers: Dict[str, str],
@@ -285,6 +328,7 @@ async def _file_chunk_get(
         ranges_and_cache_list = get_contiguous_ranges(cache_file, start_pos, end_pos)
         # Stream ranges
         for (range_start_pos, range_end_pos), is_remote in ranges_and_cache_list:
+            client = httpx.AsyncClient()
             if is_remote:
                 generator = _get_file_range_from_remote(
                     client,
@@ -350,6 +394,7 @@ async def _file_chunk_get(
                 raise Exception(
                     f"The size of cached range ({range_end_pos - range_start_pos}) is different from sent size ({cur_pos - range_start_pos})."
                 )
+            await client.aclose()
     finally:
         cache_file.close()
@@ -358,7 +403,6 @@ async def _file_chunk_head(
     app,
     save_path: str,
     head_path: str,
-    client: httpx.AsyncClient,
     method: str,
     url: str,
     headers: Dict[str, str],
@@ -366,22 +410,27 @@
     file_size: int,
 ):
     if not app.app_settings.config.offline:
-        async with client.stream(
-            method=method,
-            url=url,
-            headers=headers,
-            timeout=WORKER_API_TIMEOUT,
-        ) as response:
-            async for raw_chunk in response.aiter_raw():
-                if not raw_chunk:
-                    continue
-                yield raw_chunk
+        async with httpx.AsyncClient() as client:
+            async with client.stream(
+                method=method,
+                url=url,
+                headers=headers,
+                timeout=WORKER_API_TIMEOUT,
+            ) as response:
+                async for raw_chunk in response.aiter_raw():
+                    if not raw_chunk:
+                        continue
+                    yield raw_chunk
     else:
         yield b""
 
 
 async def _file_realtime_stream(
     app,
+    repo_type: Literal["models", "datasets", "spaces"],
+    org: str,
+    repo: str,
+    file_path: str,
     save_path: str,
     head_path: str,
     url: str,
@@ -418,93 +467,83 @@
     if "host" in request_headers:
         request_headers["host"] = urlparse(hf_url).netloc
 
-    async with httpx.AsyncClient() as client:
-        # redirect_loc = await _get_redirected_url(client, method, url, request_headers)
-        try:
-            status_code, head_info, content = await _file_full_header(
-                app=app,
-                save_path=save_path,
-                head_path=head_path,
-                client=client,
-                method="HEAD",
-                url=hf_url,
-                headers=request_headers,
-                allow_cache=allow_cache,
-            )
-            if status_code == 404:
-                yield 404
-                yield {
-                    "x-error-code": "EntryNotFound",
-                    "x-error-message": "Entry not found",
-                }
-                yield "Entry not found"
-                return
-            if status_code != 200:
-                if method.lower() == "head":
-                    yield status_code
-                    yield head_info
-                    yield content
-                    return
-                elif method.lower() == "get":
-                    yield status_code
-                    yield head_info
-                    yield content
-                    return
-                else:
-                    raise Exception("Invalid method in _file_realtime_stream parameter.")
-        except httpx.ConnectError:
-            yield 504
-            yield {}
-            yield b""
-            return
-
-    async with httpx.AsyncClient() as client:
-        file_size = int(head_info["content-length"])
-        response_headers = {k: v for k, v in head_info.items()}
-        if "range" in request_headers:
-            start_pos, end_pos = parse_range_params(
-                request_headers.get("range", f"bytes={0}-{file_size-1}"), file_size
-            )
-            response_headers["content-length"] = str(end_pos - start_pos + 1)
-        if commit is not None:
-            response_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
-
-        if app.app_settings.config.offline and "etag" not in response_headers:
-            # Create fake headers when offline mode
-            sha256_hash = hashlib.sha256()
-            sha256_hash.update(hf_url.encode("utf-8"))
-            content_hash = sha256_hash.hexdigest()
-            response_headers["etag"] = f'"{content_hash[:32]}-10"'
-        yield 200
-        yield response_headers
-        if method.lower() == "get":
-            async for each_chunk in _file_chunk_get(
-                app=app,
-                save_path=save_path,
-                head_path=head_path,
-                client=client,
-                method=method,
-                url=hf_url,
-                headers=request_headers,
-                allow_cache=allow_cache,
-                file_size=file_size,
-            ):
-                yield each_chunk
-        elif method.lower() == "head":
-            async for each_chunk in _file_chunk_head(
-                app=app,
-                save_path=save_path,
-                head_path=head_path,
-                client=client,
-                method=method,
-                url=hf_url,
-                headers=request_headers,
-                allow_cache=allow_cache,
-                file_size=0,
-            ):
-                yield each_chunk
-        else:
-            raise Exception(f"Unsupported method: {method}")
+
+    generator = pathsinfo_generator(app, repo_type, org, repo, commit, [file_path], override_cache=False, method="post")
+    headers = await generator.__anext__()
+    content = await generator.__anext__()
+    try:
+        pathsinfo = json.loads(content)
+    except json.JSONDecodeError:
+        yield 504
+        yield {}
+        yield b""
+        return
+
+    if len(pathsinfo) != 1:
+        yield 504
+        yield {}
+        yield b""
+        return
+
+    pathinfo = pathsinfo[0]
+    if "size" not in pathinfo:
+        yield 504
+        yield {}
+        yield b""
+        return
+    file_size = pathinfo["size"]
+
+    response_headers = {}
+    # Create content-length
+    start_pos, end_pos = parse_range_params(
+        request_headers.get("range", f"bytes={0}-{file_size-1}"), file_size
+    )
+    response_headers["content-length"] = str(end_pos - start_pos + 1)
+    # Commit info
+    if commit is not None:
+        response_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
+    # Create fake headers when offline mode
+    sha256_hash = hashlib.sha256()
+    sha256_hash.update(hf_url.encode("utf-8"))
+    content_hash = sha256_hash.hexdigest()
+    if app.app_settings.config.offline:
+        response_headers["etag"] = f'"{content_hash[:32]}-10"'
+    else:
+        if method.lower() == "head":
+            async with httpx.AsyncClient() as client:
+                response = await client.request(method="head", url=hf_url, headers={}, timeout=WORKER_API_TIMEOUT)
+                if "etag" in response.headers:
+                    response_headers["etag"] = response.headers["etag"]
+                else:
+                    response_headers["etag"] = f'"{content_hash[:32]}-10"'
+    yield 200
+    yield response_headers
+    if method.lower() == "get":
+        async for each_chunk in _file_chunk_get(
+            app=app,
+            save_path=save_path,
+            head_path=head_path,
+            method=method,
+            url=hf_url,
+            headers=request_headers,
+            allow_cache=allow_cache,
+            file_size=file_size,
+        ):
+            yield each_chunk
+    elif method.lower() == "head":
+        async for each_chunk in _file_chunk_head(
+            app=app,
+            save_path=save_path,
+            head_path=head_path,
+            method=method,
+            url=hf_url,
+            headers=request_headers,
+            allow_cache=allow_cache,
+            file_size=0,
+        ):
+            yield each_chunk
+    else:
+        raise Exception(f"Unsupported method: {method}")
 
 
 async def file_get_generator(
@@ -545,6 +584,10 @@
         )
         return _file_realtime_stream(
             app=app,
+            repo_type=repo_type,
+            org=org,
+            repo=repo,
+            file_path=file_path,
             save_path=save_path,
             head_path=head_path,
             url=url,
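Note on the decompression logic in _get_file_range_from_remote above: httpx's response.aiter_raw() yields the body exactly as it arrived on the wire, with no content decoding, so a response carrying a content-encoding header has to be decoded before its length can be compared against the requested byte range. The following is a minimal standalone sketch of such a decoding chain, not code from this patch; the helper name decode_content is invented for illustration. Per RFC 9110, multiple codings are listed in the order they were applied, so a decoder walks the list in reverse:

import gzip
import zlib

def decode_content(data: bytes, content_encoding: str) -> bytes:
    # Codings appear in application order, so decode right-to-left.
    for algo in reversed([a.strip().lower() for a in content_encoding.split(",")]):
        if algo in ("gzip", "x-gzip"):
            data = zlib.decompress(data, zlib.MAX_WBITS | 16)  # wbits + 16 selects the gzip wrapper
        elif algo == "deflate":
            data = zlib.decompress(data)
        elif algo == "identity":
            pass  # no transformation
        else:
            raise ValueError(f"Unsupported content-encoding: {algo}")
    return data

# Round-trip check: compress a payload with gzip, then decode it through the chain.
assert decode_content(gzip.compress(b"hello"), "gzip") == b"hello"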
diff --git a/olah/proxy/meta.py b/olah/proxy/meta.py
index f171740..c8a45fe 100644
--- a/olah/proxy/meta.py
+++ b/olah/proxy/meta.py
@@ -61,7 +61,7 @@ async def _meta_proxy_generator(
             save_path, response_status_code, response_headers, bytes(content)
         )
 
-
+# TODO: remove param `request`
 async def meta_generator(
     app: FastAPI,
     repo_type: Literal["models", "datasets", "spaces"],
diff --git a/olah/proxy/pathsinfo.py b/olah/proxy/pathsinfo.py
index a1b6b15..481fad9 100644
--- a/olah/proxy/pathsinfo.py
+++ b/olah/proxy/pathsinfo.py
@@ -35,7 +35,8 @@ async def _pathsinfo_proxy(
     save_path: str,
 ):
     headers = {k: v for k, v in headers.items()}
-    headers.pop("content-length")
+    if "content-length" in headers:
+        headers.pop("content-length")
     async with httpx.AsyncClient(follow_redirects=True) as client:
         response = await client.request(
             method=method,
@@ -64,13 +65,10 @@ async def pathsinfo_generator(
     commit: str,
     paths: List[str],
     override_cache: bool,
-    request: Request,
+    method: str,
 ):
-    headers = {k: v for k, v in request.headers.items()}
-    headers.pop("host")
-
+    headers = {}
     # save
-    method = request.method.lower()
     repos_path = app.app_settings.repos_path
 
     final_content = []
diff --git a/olah/server.py b/olah/server.py
index 6169ad4..2fd9718 100644
--- a/olah/server.py
+++ b/olah/server.py
@@ -404,9 +404,9 @@ async def pathsinfo_proxy_common(repo_type: str, org: str, repo: str, commit: st
         if commit_sha is None:
             return error_repo_not_found()
         if not app.app_settings.config.offline and commit_sha != commit:
-            generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, override_cache=True, request=request)
+            generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, override_cache=True, method=request.method.lower())
         else:
-            generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, override_cache=False, request=request)
+            generator = pathsinfo_generator(app, repo_type, org, repo, commit_sha, paths, override_cache=False, method=request.method.lower())
         headers = await generator.__anext__()
         return StreamingResponse(generator, headers=headers)
     except httpx.ConnectTimeout: