From 992d0de6bbcf3a55bf6dfdfc25dca0eb0574fc45 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 22 Dec 2023 22:18:37 -0500 Subject: [PATCH] formatting --- setup.py | 69 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/setup.py b/setup.py index 6ff37041..c7f4bf4c 100644 --- a/setup.py +++ b/setup.py @@ -11,12 +11,15 @@ MAIN_CUDA_VERSION = "12.2" + def get_hipcc_rocm_version(): # Run the hipcc --version command - result = subprocess.run(['hipcc', '--version'], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True) + result = subprocess.run( + ["hipcc", "--version"], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) # Check if the command was executed successfully if result.returncode != 0: @@ -24,7 +27,7 @@ def get_hipcc_rocm_version(): return None # Extract the version using a regular expression - match = re.search(r'HIP version: (\S+)', result.stdout) + match = re.search(r"HIP version: (\S+)", result.stdout) if match: # Return the version string return match.group(1) @@ -38,25 +41,27 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py """ - nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], - universal_newlines=True) + nvcc_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) output = nvcc_output.split() release_idx = output.index("release") + 1 nvcc_cuda_version = parse(output[release_idx].split(",")[0]) return nvcc_cuda_version + def get_requirements(): if ROCM_HOME is not None: - req_file = 'requirements-amd.txt' + req_file = "requirements-amd.txt" elif CUDA_HOME is not None: cuda_version = get_nvcc_cuda_version(CUDA_HOME) if cuda_version == Version("11.8"): - req_file = 'requirements-cu118.txt' + req_file = "requirements-cu118.txt" else: - req_file = 'requirements.txt' + req_file = "requirements.txt" else: - req_file = 'requirements-cpu.txt' - + req_file = "requirements-cpu.txt" + with open(req_file) as f: requirements = f.read().splitlines() return requirements @@ -72,8 +77,9 @@ def find_version(filepath: str) -> str: Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py """ with open(filepath) as fp: - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - fp.read(), re.M) + version_match = re.search( + r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M + ) if version_match: return version_match.group(1) raise RuntimeError("Unable to find version string.") @@ -81,7 +87,7 @@ def find_version(filepath: str) -> str: def get_tabbyapi_version() -> str: version = find_version(get_path("tabbyapi", "__init__.py")) - + if ROCM_HOME is not None: # get the HIP version hipcc_version = get_hipcc_rocm_version() @@ -93,22 +99,22 @@ def get_tabbyapi_version() -> str: if cuda_version is not None: cuda_version_str = str(cuda_version) # Split the version into numerical and suffix parts - version_parts = version.split('-') + version_parts = version.split("-") version_num = version_parts[0] - version_suffix = version_parts[1] if len(version_parts) > 1 else '' - + version_suffix = version_parts[1] if len(version_parts) > 1 else "" + if cuda_version_str != MAIN_CUDA_VERSION: cuda_version_str = cuda_version_str.replace(".", "")[:3] version_num += f"+cu{cuda_version_str}" - + # Reassemble the version string with the suffix, if any - version = version_num + ('-' + - version_suffix if version_suffix else '') + version = version_num + ("-" + version_suffix if version_suffix else "") else: version += "+cpu" - + return version + def read_readme() -> str: p = get_path("README.md") if os.path.isfile(p): @@ -116,6 +122,7 @@ def read_readme() -> str: else: return "" + setup( name="tabbyapi", version=find_version(get_path("tabbyapi", "__init__.py")), @@ -124,22 +131,22 @@ def read_readme() -> str: long_description_content_type="text/markdown", author="The Royal Lab", url="https://github.com/theroyallab/tabbyAPI", - license='AGPL 3.0', - packages=find_packages(exclude=["tests", "examples", - "models", "loras", - "templates", "Docker"]), + license="AGPL 3.0", + packages=find_packages( + exclude=["tests", "examples", "models", "loras", "templates", "Docker"] + ), install_requires=get_requirements(), - python_requires='>=3.10, <3.12', + python_requires=">=3.10, <3.12", entry_points={ - 'console_scripts': [ - 'tabbyapi=tabbyapi.main:main', + "console_scripts": [ + "tabbyapi=tabbyapi.main:main", ], }, classifiers=[ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", + "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", # noqa: E501 "Topic :: Scientific/Engineering :: Artificial Intelligence", ], # package_data={"tabbyapi": ["config.yml"]}, -) \ No newline at end of file +)