Skip to content

Commit

Permalink
Start: Add gpu_lib argument
Browse files Browse the repository at this point in the history
Argument to override the selected GPU library. Useful for daemoniztion
when running for the first time.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Apr 9, 2024
1 parent d759a15 commit de41e9f
Showing 1 changed file with 45 additions and 36 deletions.
81 changes: 45 additions & 36 deletions start.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from common.args import convert_args_to_dict, init_argparser


def get_user_choice(question, options_dict):
def get_user_choice(question: str, options_dict: dict):
"""
Gets user input in a commandline script.
Expand All @@ -34,50 +34,54 @@ def get_user_choice(question, options_dict):
return choice


def get_install_features():
def get_install_features(lib_name: str = None):
"""Fetches the appropriate requirements file depending on the GPU"""
install_features = None
possible_features = ["cu121", "cu118", "amd"]

# Try getting the GPU lib from a file
saved_lib_path = pathlib.Path("gpu_lib.txt")
if saved_lib_path.exists():
with open(saved_lib_path.resolve(), "r") as f:
lib = f.readline().strip()

# Assume default if the file is invalid
if lib not in possible_features:
if lib_name:
print("Overriding GPU lib name from args.")
else:
# Try getting the GPU lib from file
if saved_lib_path.exists():
print(saved_lib_path)
with open(saved_lib_path.resolve(), "r") as f:
lib = f.readline().strip()
else:
# Ask the user for the GPU lib
gpu_lib_choices = {
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
"B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
"C": {"pretty": "AMD", "internal": "amd"},
}
user_input = get_user_choice(
"Select your GPU. If you don't know, select Cuda 12.x (A)",
gpu_lib_choices,
)

lib_name = gpu_lib_choices.get(user_input, {}).get("internal")

# Write to a file for subsequent runs
with open(saved_lib_path.resolve(), "w") as f:
f.write(lib_name)
print(
f"WARN: GPU library {lib} not found. "
"Skipping GPU-specific dependencies.\n"
"WARN: Please delete gpu_lib.txt and restart "
"if you want to change your selection."
"Saving your choice to gpu_lib.txt. "
"Delete this file and restart if you want to change your selection."
)
return

print(f"Using {lib} dependencies from your preferences.")
install_features = lib
# Assume default if the file is invalid
if lib_name and lib_name in possible_features:
print(f"Using {lib_name} dependencies from your preferences.")
install_features = lib_name
else:
# Ask the user for the GPU lib
gpu_lib_choices = {
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
"B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
"C": {"pretty": "AMD", "internal": "amd"},
}
user_input = get_user_choice(
"Select your GPU. If you don't know, select Cuda 12.x (A)",
gpu_lib_choices,
print(
f"WARN: GPU library {lib} not found. "
"Skipping GPU-specific dependencies.\n"
"WARN: Please delete gpu_lib.txt and restart "
"if you want to change your selection."
)

install_features = gpu_lib_choices.get(user_input, {}).get("internal")

# Write to a file for subsequent runs
with open(saved_lib_path.resolve(), "w") as f:
f.write(install_features)
print(
"Saving your choice to gpu_lib.txt. "
"Delete this file and restart if you want to change your selection."
)
return

if install_features == "amd":
# Exit if using AMD and Windows
Expand Down Expand Up @@ -111,6 +115,11 @@ def add_start_args(parser: argparse.ArgumentParser):
action="store_true",
help="Don't upgrade wheel dependencies (exllamav2, torch)",
)
start_group.add_argument(
"--gpu-lib",
type=str,
help="Select GPU library. Options: cu121, cu118, amd",
)


if __name__ == "__main__":
Expand All @@ -124,7 +133,7 @@ def add_start_args(parser: argparse.ArgumentParser):
if args.ignore_upgrade:
print("Ignoring pip dependency upgrade due to user request.")
else:
install_features = None if args.nowheel else get_install_features()
install_features = None if args.nowheel else get_install_features(args.gpu_lib)
features = f"[{install_features}]" if install_features else ""

# pip install .[features]
Expand Down

0 comments on commit de41e9f

Please sign in to comment.