diff --git a/olah/configs.py b/olah/configs.py index e6b954b..6701d5a 100644 --- a/olah/configs.py +++ b/olah/configs.py @@ -5,7 +5,7 @@ # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. -from typing import List, Optional +from typing import List, Optional, Union import toml import re import fnmatch @@ -78,7 +78,7 @@ class OlahConfig(object): def __init__(self, path: Optional[str] = None) -> None: # basic - self.host = "localhost" + self.host: Union[List[str], str] = "localhost" self.port = 8090 self.ssl_key = None self.ssl_cert = None @@ -90,10 +90,10 @@ def __init__(self, path: Optional[str] = None) -> None: self.mirror_scheme: str = "http" if self.ssl_key is None else "https" self.mirror_netloc: str = ( - f"{self.host if self.host != '0.0.0.0' else 'localhost'}:{self.port}" + f"{self.host if self._is_specific_addr(self.host) else 'localhost'}:{self.port}" ) self.mirror_lfs_netloc: str = ( - f"{self.host if self.host != '0.0.0.0' else 'localhost'}:{self.port}" + f"{self.host if self._is_specific_addr(self.host) else 'localhost'}:{self.port}" ) self.mirrors_path: List[str] = [] @@ -105,6 +105,12 @@ def __init__(self, path: Optional[str] = None) -> None: if path is not None: self.read_toml(path) + + def _is_specific_addr(self, host: Union[List[str], str]) -> bool: + if isinstance(host, str): + return host not in ['0.0.0.0', '::'] + else: + return False def hf_url_base(self) -> str: return f"{self.hf_scheme}://{self.hf_netloc}" diff --git a/olah/server.py b/olah/server.py index 2736f37..6eba6cf 100644 --- a/olah/server.py +++ b/olah/server.py @@ -713,6 +713,10 @@ def is_default_value(args, arg_name): args.ssl_cert = config.ssl_cert if is_default_value(args, "repos_path"): args.repos_path = config.repos_path + + # Post processing + if "," in args.host: + args.host = args.host.split(",") app.app_settings = AppSettings( config=config,