diff --git a/.gitignore b/.gitignore index f77b7f9e..1c597fad 100644 --- a/.gitignore +++ b/.gitignore @@ -196,3 +196,6 @@ templates/* # Sampler overrides folder sampler_overrides/* !sampler_overrides/sample_preset.yml + +# Gpu lib preferences file +gpu_lib.txt diff --git a/start.py b/start.py index 458fe885..24ad5d5a 100644 --- a/start.py +++ b/start.py @@ -4,30 +4,94 @@ import argparse import os import pathlib +import platform import subprocess +import sys from common.args import convert_args_to_dict, init_argparser +def get_user_choice(question, options_dict): + """ + Gets user input in a commandline script. + + Originally from: https://github.com/oobabooga/text-generation-webui/blob/main/one_click.py#L213 + """ + + print() + print(question) + print() + + for key, value in options_dict.items(): + print(f"{key}) {value.get('pretty')}") + + print() + + choice = input("Input> ").upper() + while choice not in options_dict.keys(): + print("Invalid choice. Please try again.") + choice = input("Input> ").upper() + + return choice + + def get_install_features(): """Fetches the appropriate requirements file depending on the GPU""" install_features = None - ROCM_PATH = os.environ.get("ROCM_PATH") - CUDA_PATH = os.environ.get("CUDA_PATH") + possible_features = ["cu121", "cu118", "amd"] + + # Try getting the GPU lib from a file + saved_lib_path = pathlib.Path("gpu_lib.txt") + if saved_lib_path.exists(): + with open(saved_lib_path.resolve(), "r") as f: + lib = f.readline() + + # Assume default if the file is invalid + if lib not in possible_features: + print( + f"WARN: GPU library {lib} not found. " + "Skipping GPU-specific dependencies.\n" + "WARN: Please delete gpu_lib.txt and restart " + "if you want to change your selection." + ) + return + + print(f"Using {lib} dependencies from your preferences.") + install_features = lib + else: + # Ask the user for the GPU lib + gpu_lib_choices = { + "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"}, + "B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"}, + "C": {"pretty": "AMD", "internal": "amd"}, + } + user_input = get_user_choice( + "Select your GPU. If you don't know, select Cuda 12.x (A)", + gpu_lib_choices, + ) + + install_features = gpu_lib_choices.get(user_input, {}).get("internal") + + # Write to a file for subsequent runs + with open(saved_lib_path.resolve(), "w") as f: + f.write(install_features) + print( + "Saving your choice to gpu_lib.txt. " + "Delete this file and restart if you want to change your selection." + ) - # TODO: Check if the user has an AMD gpu on windows - if ROCM_PATH: - install_features = "amd" + if install_features == "amd": + # Exit if using AMD and Windows + if platform.system() == "Windows": + print( + "ERROR: TabbyAPI does not support AMD and Windows. " + "Please use Linux and ROCm 5.6. Exiting." + ) + sys.exit(0) - # Also override env vars for ROCm support on non-supported GPUs + # Override env vars for ROCm support on non-supported GPUs os.environ["ROCM_PATH"] = "/opt/rocm" os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0" os.environ["HCC_AMDGPU_TARGET"] = "gfx1030" - elif CUDA_PATH: - cuda_version = pathlib.Path(CUDA_PATH).name - if "12" in cuda_version: - install_features = "cu121" - elif "11" in cuda_version: - install_features = "cu118" return install_features