diff --git a/start.bat b/start.bat index 688bfc94..5816eed8 100644 --- a/start.bat +++ b/start.bat @@ -1,11 +1,20 @@ -:: From https://github.com/jllllll/windows-venv-installers/blob/main/Powershell/run-ps-script.cmd @echo off -set SCRIPT_NAME=start.ps1 +:: Creates a venv if it doesn't exist and runs the start script for requirements upgrades +:: This is intended for users who want to start the API and have everything upgraded and installed -:: This will run the Powershell script named above in the current directory -:: This is intended for systems who have not changed the script execution policy from default -:: These systems will be unable to directly execute Powershell scripts unless done through CMD.exe like below +cd "%~dp0" -if not exist "%~dp0\%SCRIPT_NAME%" ( echo %SCRIPT_NAME% not found! && pause && goto eof ) -call powershell.exe -executionpolicy Bypass ". '%~dp0\start.ps1' %*" +:: Don't create a venv if a conda environment is active +if exist "%CONDA_PREFIX%" ( + echo It looks like you're in a conda environment. Skipping venv check. +) else ( + if not exist "venv\" ( + echo "Venv doesn't exist! Creating one for you." + python -m venv venv + call .\venv\Scripts\activate.bat + ) +) + +:: Call the python script with batch args +call python start.py %* diff --git a/start.ps1 b/start.ps1 deleted file mode 100644 index cc17d6bd..00000000 --- a/start.ps1 +++ /dev/null @@ -1,75 +0,0 @@ -# Arg parsing -param( - [switch]$ignore_upgrade = $false, - [switch]$nowheel = $false, - [switch]$activate_venv = $false -) - -# Gets the currently installed CUDA version -function GetRequirementsFile { - $GpuInfo = (Get-WmiObject Win32_VideoController).Name - if ($GpuInfo.Contains("AMD")) { - Write-Error "AMD/ROCm isn't supported on Windows. Please switch to linux." - exit - } - - # Install nowheel if specified - if ($nowheel) { - Write-Host "Not installing wheels due to user request." - return "requirements-nowheel" - } - - $CudaPath = $env:CUDA_PATH - $CudaVersion = Split-Path $CudaPath -Leaf - - # Decide requirements based on CUDA version - if ($CudaVersion.Contains("12")) { - return "requirements" - } elseif ($CudaVersion.Contains("11.8")) { - return "requirements-cu118" - } else { - Write-Host "Script cannot find your CUDA installation. installing from requirements-nowheel.txt" - return "requirements-nowheel" - } -} - -# Make a venv and enter it -function CreateAndActivateVenv { - # Is the user using conda? - if ($null -ne $env:CONDA_PREFIX) { - Write-Host "It looks like you're in a conda environment. Skipping venv check." - return - } - - $VenvDir = "$PSScriptRoot\venv" - - if (!(Test-Path -Path $VenvDir)) { - Write-Host "Venv doesn't exist! Creating one for you." - python -m venv venv - } - - . "$VenvDir\Scripts\activate.ps1" - - if ($activate_venv) { - Write-Host "Stopping at venv activation due to user request." - exit - } -} - -# Entrypoint for API start -function StartAPI { - pip -V - if ($ignore_upgrade) { - Write-Host "Ignoring pip dependency upgrade due to user request." - } else { - pip install --upgrade -r "$RequirementsFile.txt" - } - - python main.py -} - -# Navigate to the script directory -Set-Location $PSScriptRoot -$RequirementsFile = GetRequirementsFile -CreateAndActivateVenv -StartAPI diff --git a/start.py b/start.py new file mode 100644 index 00000000..ba220e62 --- /dev/null +++ b/start.py @@ -0,0 +1,58 @@ +"""Utility to automatically upgrade and start the API""" +import argparse +import os +import pathlib +import subprocess +from main import entrypoint + + +def get_requirements_file(): + """Fetches the appropriate requirements file depending on the GPU""" + requirements_name = "requirements-nowheel" + ROCM_PATH = os.environ.get("ROCM_PATH") + CUDA_PATH = os.environ.get("CUDA_PATH") + + # TODO: Check if the user has an AMD gpu on windows + if ROCM_PATH: + requirements_name = "requirements-amd" + elif CUDA_PATH: + cuda_version = pathlib.Path(CUDA_PATH).name + if "12" in cuda_version: + requirements_name = "requirements" + elif "11" in cuda_version: + requirements_name = "requirements-cu118" + + return requirements_name + + +def get_argparser(): + """Fetches the argparser for this script""" + parser = argparse.ArgumentParser() + parser.add_argument( + "-iu", + "--ignore-upgrade", + action="store_true", + help="Ignore requirements upgrade", + ) + parser.add_argument( + "-nw", + "--nowheel", + action="store_true", + help="Don't upgrade wheel dependencies (exllamav2, torch)", + ) + return parser + + +if __name__ == "__main__": + parser = get_argparser() + args = parser.parse_args() + + if args.ignore_upgrade: + print("Ignoring pip dependency upgrade due to user request.") + else: + requirements_file = ( + "requirements-nowheel" if args.nowheel else get_requirements_file() + ) + subprocess.run(f"pip install -U -r {requirements_file}.txt") + + entrypoint()