diff --git a/xinference/model/image/stable_diffusion/core.py b/xinference/model/image/stable_diffusion/core.py
index eda1e1c412..315f5a8b0f 100644
--- a/xinference/model/image/stable_diffusion/core.py
+++ b/xinference/model/image/stable_diffusion/core.py
@@ -23,6 +23,7 @@
 
 from ....constants import XINFERENCE_IMAGE_DIR
 from ....types import Image, ImageList
+from ...llm.utils import select_device
 
 
 class DiffusionModel:
@@ -36,8 +37,17 @@ def __init__(
         self._kwargs = kwargs
 
     def load(self):
-        import torch
-        from diffusers import AutoPipelineForText2Image
+        try:
+            from diffusers import AutoPipelineForText2Image
+        except ImportError:
+            raise ImportError(
+                "Failed to import module 'diffusers'. Please make sure 'diffusers' is installed.\n\n"
+            )
+
+        device = self._kwargs.get("device", "auto")
+        self._kwargs["device"] = select_device(device)
+        if self._kwargs["device"] == "cuda":
+            self._kwargs.setdefault("device_map", "auto")
 
         self._model = AutoPipelineForText2Image.from_pretrained(
             self._model_path,
@@ -46,9 +56,7 @@ def load(self):
             # torch_dtype=torch.float16,
             # use_safetensors=True,
         )
-        if torch.cuda.is_available():
-            self._model = self._model.to("cuda")
-        elif torch.backends.mps.is_available():
+        if self._kwargs["device"] == "mps":
             self._model = self._model.to("mps")
             # Recommended if your computer has < 64 GB of RAM
             self._model.enable_attention_slicing()
diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py
index 05a5d64d71..75a5e2a31c 100644
--- a/xinference/model/llm/pytorch/core.py
+++ b/xinference/model/llm/pytorch/core.py
@@ -30,7 +30,7 @@
 )
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import ChatModelMixin
+from ..utils import ChatModelMixin, select_device
 
 logger = logging.getLogger(__name__)
 
@@ -120,7 +120,7 @@ def load(self):
         quantization = self.quantization
         num_gpus = len(cuda_visible_devices) if cuda_visible_devices_env != "-1" else 0
         device = self._pytorch_model_config.get("device", "auto")
-        self._pytorch_model_config["device"] = self._select_device(device)
+        self._pytorch_model_config["device"] = select_device(device)
         self._device = self._pytorch_model_config["device"]
 
         if self._device == "cpu":
@@ -183,33 +183,6 @@ def load(self):
             self._model.to(self._device)
         logger.debug(f"Model Memory: {self._model.get_memory_footprint()}")
 
-    def _select_device(self, device: str) -> str:
-        try:
-            import torch
-        except ImportError:
-            raise ImportError(
-                f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
-            )
-
-        if device == "auto":
-            # When env CUDA_VISIBLE_DEVICES=-1, torch.cuda.is_available() return False
-            if torch.cuda.is_available():
-                return "cuda"
-            elif torch.backends.mps.is_available():
-                return "mps"
-            return "cpu"
-        elif device == "cuda":
-            if not torch.cuda.is_available():
-                raise ValueError("cuda is unavailable in your environment")
-        elif device == "mps":
-            if not torch.backends.mps.is_available():
-                raise ValueError("mps is unavailable in your environment")
-        elif device == "cpu":
-            pass
-        else:
-            raise ValueError(f"Device {device} is not supported in temporary")
-        return device
-
     @classmethod
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py
index 706652a79d..5bae9c8807 100644
--- a/xinference/model/llm/utils.py
+++ b/xinference/model/llm/utils.py
@@ -268,3 +268,31 @@ def is_valid_model_name(model_name: str) -> bool:
     import re
 
     return re.match(r"^[A-Za-z0-9][A-Za-z0-9_\-]*$", model_name) is not None
+
+
+def select_device(device: str) -> str:
+    try:
+        import torch
+    except ImportError:
+        raise ImportError(
+            "Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
+        )
+
+    if device == "auto":
+        # When env CUDA_VISIBLE_DEVICES=-1, torch.cuda.is_available() returns False
+        if torch.cuda.is_available():
+            return "cuda"
+        elif torch.backends.mps.is_available():
+            return "mps"
+        return "cpu"
+    elif device == "cuda":
+        if not torch.cuda.is_available():
+            raise ValueError("cuda is unavailable in your environment")
+    elif device == "mps":
+        if not torch.backends.mps.is_available():
+            raise ValueError("mps is unavailable in your environment")
+    elif device == "cpu":
+        pass
+    else:
+        raise ValueError(f"Device {device} is currently not supported")
+    return device
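For reviewers, a minimal usage sketch of the shared select_device helper extracted above. The call site below is hypothetical and only illustrates the behavior visible in the diff:

    # Resolution order for "auto": "cuda" if torch.cuda.is_available(),
    # else "mps" on Apple silicon, else "cpu".
    from xinference.model.llm.utils import select_device

    device = select_device("auto")  # e.g. "cuda" on a GPU machine

    # Explicit requests are validated rather than silently remapped:
    # select_device("cuda") raises ValueError when no CUDA device is
    # visible (for example, when CUDA_VISIBLE_DEVICES=-1).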