diff --git a/olah/proxy/files.py b/olah/proxy/files.py
index 1fd9b8f..9da6d91 100644
--- a/olah/proxy/files.py
+++ b/olah/proxy/files.py
@@ -85,125 +85,6 @@ def get_contiguous_ranges(
     range_start_pos = end_pos
     return ranges_and_cache_list
 
-async def _file_full_header(
-    app,
-    save_path: str,
-    head_path: str,
-    client: httpx.AsyncClient,
-    method: str,
-    url: str,
-    headers: Dict[str, str],
-    allow_cache: bool,
-) -> Tuple[int, Dict[str, str], bytes]:
-    assert method.lower() == "head"
-    if not app.app_settings.config.offline:
-        if os.path.exists(head_path):
-            cache_rq = await read_cache_request(head_path)
-            response_headers_dict = {
-                k.lower(): v for k, v in cache_rq["headers"].items()
-            }
-            if "location" in response_headers_dict:
-                parsed_url = urlparse(response_headers_dict["location"])
-                if len(parsed_url.netloc) != 0:
-                    new_loc = urljoin(
-                        app.app_settings.config.mirror_lfs_url_base(),
-                        get_url_tail(response_headers_dict["location"]),
-                    )
-                    response_headers_dict["location"] = new_loc
-            return cache_rq["status_code"], response_headers_dict, cache_rq["content"]
-        else:
-            if "range" in headers:
-                headers.pop("range")
-            response = await client.request(
-                method=method,
-                url=url,
-                headers=headers,
-                timeout=WORKER_API_TIMEOUT,
-            )
-            response_headers_dict = {k.lower(): v for k, v in response.headers.items()}
-            if allow_cache and method.lower() == "head":
-                if response.status_code == 200:
-                    await write_cache_request(
-                        head_path,
-                        response.status_code,
-                        response_headers_dict,
-                        response.content,
-                    )
-                elif response.status_code >= 300 and response.status_code <= 399:
-                    from_url = urlparse(url)
-                    parsed_url = urlparse(response.headers["location"])
-                    if len(parsed_url.netloc) != 0:
-                        new_loc = urljoin(
-                            app.app_settings.config.mirror_lfs_url_base(),
-                            get_url_tail(response.headers["location"]),
-                        )
-                        response_headers_dict["location"] = new_loc
-                    # Redirect, add original location info
-                    if check_url_has_param_name(
-                        response_headers_dict["location"], ORIGINAL_LOC
-                    ):
-                        raise Exception(f"Invalid field {ORIGINAL_LOC} in the url.")
-                    else:
-                        response_headers_dict["location"] = add_query_param(
-                            response_headers_dict["location"],
-                            ORIGINAL_LOC,
-                            response.headers["location"],
-                        )
-                    await write_cache_request(
-                        head_path,
-                        response.status_code,
-                        response_headers_dict,
-                        response.content,
-                    )
-                elif response.status_code == 403:
-                    pass
-                elif response.status_code == 404:
-                    pass
-                else:
-                    raise Exception(
-                        f"Unexpected HTTP status code {response.status_code}"
-                    )
-            return response.status_code, response_headers_dict, response.content
-    else:
-        if os.path.exists(head_path):
-            cache_rq = await read_cache_request(head_path)
-            response_headers_dict = {
-                k.lower(): v for k, v in cache_rq["headers"].items()
-            }
-        else:
-            response_headers_dict = {}
-            cache_rq = {
-                "status_code": 200,
-                "headers": response_headers_dict,
-                "content": b"",
-            }
-
-    new_headers = {}
-    if "content-type" in response_headers_dict:
-        new_headers["content-type"] = response_headers_dict["content-type"]
-    if "content-length" in response_headers_dict:
-        new_headers["content-length"] = response_headers_dict["content-length"]
-    if HUGGINGFACE_HEADER_X_REPO_COMMIT.lower() in response_headers_dict:
-        new_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = (
-            response_headers_dict.get(HUGGINGFACE_HEADER_X_REPO_COMMIT.lower(), "")
-        )
-    if HUGGINGFACE_HEADER_X_LINKED_ETAG.lower() in response_headers_dict:
-        new_headers[HUGGINGFACE_HEADER_X_LINKED_ETAG.lower()] = (
-            response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_ETAG.lower(), "")
-        )
-    if HUGGINGFACE_HEADER_X_LINKED_SIZE.lower() in response_headers_dict:
-        new_headers[HUGGINGFACE_HEADER_X_LINKED_SIZE.lower()] = (
-            response_headers_dict.get(HUGGINGFACE_HEADER_X_LINKED_SIZE.lower(), "")
-        )
-    if "etag" in response_headers_dict:
-        new_headers["etag"] = response_headers_dict["etag"]
-    if "location" in response_headers_dict:
-        new_headers["location"] = urljoin(
-            app.app_settings.config.mirror_lfs_url_base(),
-            get_url_tail(response_headers_dict["location"]),
-        )
-    return cache_rq["status_code"], new_headers, cache_rq["content"]
-
 
 async def _get_file_range_from_cache(
     cache_file: OlahCache, start_pos: int, end_pos: int
@@ -238,7 +119,8 @@ async def _get_file_range_from_remote(
     end_pos: int,
 ):
     headers = {}
-    headers["authorization"] = remote_info.headers.get("authorization", None)
+    if remote_info.headers.get("authorization", None) is not None:
+        headers["authorization"] = remote_info.headers.get("authorization", None)
     headers["range"] = f"bytes={start_pos}-{end_pos - 1}"
 
     chunk_bytes = 0
@@ -433,6 +315,32 @@ async def _file_chunk_head(
         yield b""
 
 
+async def _resource_etag(hf_url: str, authorization: Optional[str]=None, offline: bool = False) -> Optional[str]:
+    ret_etag = None
+    sha256_hash = hashlib.sha256()
+    sha256_hash.update(hf_url.encode("utf-8"))
+    content_hash = sha256_hash.hexdigest()
+    if offline:
+        ret_etag = f'"{content_hash[:32]}-10"'
+    else:
+        etag_headers = {}
+        if authorization is not None:
+            etag_headers["authorization"] = authorization
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.request(
+                    method="head",
+                    url=hf_url,
+                    headers=etag_headers,
+                    timeout=WORKER_API_TIMEOUT,
+                )
+                if "etag" in response.headers:
+                    ret_etag = response.headers["etag"]
+                else:
+                    ret_etag = f'"{content_hash[:32]}-10"'
+        except httpx.TimeoutException:
+            ret_etag = None
+    return ret_etag
 async def _file_realtime_stream(
     app,
     repo_type: Literal["models", "datasets", "spaces"],
@@ -531,28 +439,22 @@ async def _file_realtime_stream(
     if commit is not None:
         response_headers[HUGGINGFACE_HEADER_X_REPO_COMMIT.lower()] = commit
     # Create fake headers when offline mode
-    sha256_hash = hashlib.sha256()
-    sha256_hash.update(hf_url.encode("utf-8"))
-    content_hash = sha256_hash.hexdigest()
-    if app.app_settings.config.offline:
-        response_headers["etag"] = f'"{content_hash[:32]}-10"'
+    etag = await _resource_etag(
+        hf_url=hf_url,
+        authorization=request.headers.get("authorization", None),
+        offline=app.app_settings.config.offline,
+    )
+    response_headers["etag"] = etag
+
+    if etag is None:
+        error_response = error_proxy_timeout()
+        yield error_response.status_code
+        yield error_response.headers
+        yield error_response.body
+        return
     else:
-        if method.lower() == "head":
-            async with httpx.AsyncClient() as client:
-                response = await client.request(
-                    method="head",
-                    url=hf_url,
-                    headers={
-                        "authorization": request.headers.get("authorization", None)
-                    },
-                    timeout=WORKER_API_TIMEOUT,
-                )
-            if "etag" in response.headers:
-                response_headers["etag"] = response.headers["etag"]
-            else:
-                response_headers["etag"] = f'"{content_hash[:32]}-10"'
-    yield 200
-    yield response_headers
+        yield 200
+        yield response_headers
 
     async with httpx.AsyncClient() as client:
         if method.lower() == "get":
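
Illustrative usage sketch (not part of the patch): a minimal, self-contained check of how the new _resource_etag helper extracted by this change behaves, assuming the patched olah package is importable. The example URL is hypothetical.

import asyncio

from olah.proxy.files import _resource_etag


async def main() -> None:
    # Hypothetical example URL; any upstream resolve URL works the same way.
    url = "https://huggingface.co/gpt2/resolve/main/config.json"

    # offline=True: no network request is made; the helper returns a
    # deterministic fake ETag derived from the SHA-256 hash of the URL.
    print("offline etag:", await _resource_etag(url, offline=True))

    # offline=False: a HEAD request is sent upstream; on httpx.TimeoutException
    # the helper returns None, which _file_realtime_stream now turns into an
    # error_proxy_timeout() response instead of yielding a 200.
    print("online etag:", await _resource_etag(url, authorization=None, offline=False))


if __name__ == "__main__":
    asyncio.run(main())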