Skip to content

Commit

Permalink
v3.0.1
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Jul 1, 2024
1 parent 68d24d9 commit 14d7d9b
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions ct2_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,31 @@
import os
import glob
from pathlib import Path
from PySide6.QtWidgets import QApplication
from ct2_gui import MyWindow
from ct2_utils import CheckQuantizationSupport

def set_cuda_paths():
script_dir = Path(__file__).parent.resolve()
nvidia_base_path = script_dir / 'Lib' / 'site-packages' / 'nvidia'
cublas_bin_path = script_dir / 'Lib' / 'site-packages' / 'nvidia' / 'cublas' / 'bin'
cudnn_bin_path = script_dir / 'Lib' / 'site-packages' / 'nvidia' / 'cudnn' / 'bin'

# Set CUDA_PATH and CUDA_PATH_V12_2
for env_var in ['CUDA_PATH', 'CUDA_PATH_V12_2']:
current_path = os.environ.get(env_var, '')
new_paths = [str(nvidia_base_path), str(cublas_bin_path), str(cudnn_bin_path)]
os.environ[env_var] = os.pathsep.join(filter(None, new_paths + [current_path]))

# Add nvidia folder, cudnn bin folder, and cublas bin folder to system PATH
current_path = os.environ.get('PATH', '')
new_paths = [str(nvidia_base_path), str(cublas_bin_path), str(cudnn_bin_path)]
new_path = os.pathsep.join(filter(None, new_paths + [current_path]))
os.environ['PATH'] = new_path
cuda_path = script_dir / 'Lib' / 'site-packages' / 'nvidia'
cublas_path = cuda_path / 'cublas' / 'bin'
cudnn_path = cuda_path / 'cudnn' / 'bin'

paths_to_add = [str(cuda_path), str(cublas_path), str(cudnn_path)]

env_vars = ['CUDA_PATH', 'CUDA_PATH_V12_2', 'PATH']

for env_var in env_vars:
current_value = os.environ.get(env_var, '')
new_value = os.pathsep.join(paths_to_add + [current_value] if current_value else paths_to_add)
os.environ[env_var] = new_value

print("CUDA paths have been set or updated in the environment variables.")

set_cuda_paths()

from PySide6.QtWidgets import QApplication
from ct2_gui import MyWindow
from ct2_utils import CheckQuantizationSupport

if __name__ == "__main__":
set_cuda_paths()

quantization_checker = CheckQuantizationSupport()
cuda_available = quantization_checker.has_cuda_device()
quantization_checker.update_supported_quantizations()
Expand Down

0 comments on commit 14d7d9b

Please sign in to comment.