diff --git a/tools/inference_engine/reference_loader.py b/tools/inference_engine/reference_loader.py
index 91232eef..4b560393 100644
--- a/tools/inference_engine/reference_loader.py
+++ b/tools/inference_engine/reference_loader.py
@@ -50,7 +50,7 @@ def load_by_id(
             # If the references are not already loaded, encode them
             prompt_tokens = [
                 self.encode_reference(
-                    decoder_model=self.decoder_model,
+                    # decoder_model=self.decoder_model,
                     reference_audio=audio_to_bytes(str(ref_audio)),
                     enable_reference_audio=True,
                 )
diff --git a/tools/run_webui.py b/tools/run_webui.py
index 6b0ab490..5844b72c 100644
--- a/tools/run_webui.py
+++ b/tools/run_webui.py
@@ -45,6 +45,11 @@ def parse_args():
     args = parse_args()
     args.precision = torch.half if args.half else torch.bfloat16
 
+    # Check if MPS is available
+    if torch.backends.mps.is_available():
+        args.device = "mps"
+        logger.info("mps is available, running on mps.")
+
     # Check if CUDA is available
     if not torch.cuda.is_available():
         logger.info("CUDA is not available, running on CPU.")
diff --git a/tools/server/model_manager.py b/tools/server/model_manager.py
index 549ad8d4..c3f0a896 100644
--- a/tools/server/model_manager.py
+++ b/tools/server/model_manager.py
@@ -34,6 +34,11 @@ def __init__(
         self.precision = torch.half if half else torch.bfloat16
 
+        # Check if MPS is available
+        if torch.backends.mps.is_available():
+            self.device = "mps"
+            logger.info("mps is available, running on mps.")
+
         # Check if CUDA is available
         if not torch.cuda.is_available():
             self.device = "cpu"
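
Note on the device-selection order in the last two hunks: both run_webui.py and model_manager.py pick "mps" before the existing CUDA check, and on Apple Silicon torch.cuda.is_available() returns False, so if the CUDA branch also assigns a device (as its "running on CPU" log line suggests), it could overwrite the "mps" choice with "cpu". Below is a minimal sketch of a guard that avoids this, assuming the surrounding code only needs a device string; the helper name select_device is hypothetical and not part of this repo, while torch.backends.mps.is_available() and torch.cuda.is_available() are real PyTorch APIs.

import torch

def select_device() -> str:
    # Prefer Apple's Metal backend when it is built and usable
    # (hypothetical helper; not part of the patched files).
    if torch.backends.mps.is_available():
        return "mps"
    # Otherwise fall back to CUDA, then CPU.
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

Early returns keep the three fallbacks mutually exclusive, so a later branch can never clobber an earlier choice; the same effect can be had in the patched files by turning the existing CUDA check into an elif of the new MPS check.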