From 48edab8041741a82a1fd2f4d463cc0f393561b05 Mon Sep 17 00:00:00 2001
From: Akash kaothalkar <61960177+Akashcodes732@users.noreply.github.com>
Date: Fri, 20 Dec 2024 07:02:07 +0530
Subject: [PATCH] [Bugfix][Hardware][POWERPC] Fix auto dtype failure in case
 of POWER10 (#11331)

Signed-off-by: Akash Kaothalkar <0052v2@linux.vnet.ibm.com>
---
 vllm/config.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/vllm/config.py b/vllm/config.py
index 0e886e18fcd6d..6badae24d9d7d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -22,7 +22,7 @@
 from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
-from vllm.platforms import current_platform
+from vllm.platforms import current_platform, interface
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2199,6 +2199,17 @@ def _get_and_verify_dtype(
         else:
             torch_dtype = config_dtype
 
+        if (current_platform.is_cpu()
+                and current_platform.get_cpu_architecture()
+                == interface.CpuArchEnum.POWERPC
+                and (config_dtype == torch.float16
+                     or config_dtype == torch.float32)):
+            logger.info(
+                "For POWERPC, we cast models to bfloat16 instead of "
+                "using float16 by default. Float16 is not currently "
+                "supported for POWERPC.")
+            torch_dtype = torch.bfloat16
+
         if current_platform.is_hpu() and config_dtype == torch.float16:
             logger.info(
                 "For HPU, we cast models to bfloat16 instead of"
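
For reviewers who want to exercise the new branch in isolation, below is a
minimal self-contained sketch of the dtype resolution this patch implements.
The resolve_torch_dtype helper and the stubbed CpuArchEnum are hypothetical
stand-ins for illustration only; the real check lives inside
_get_and_verify_dtype in vllm/config.py and probes
vllm.platforms.current_platform. Only torch is needed to run it.

# Hypothetical standalone sketch of the POWERPC dtype override above.
# CpuArchEnum here is a stand-in for vllm.platforms.interface.CpuArchEnum.
import logging
from enum import Enum

import torch

logger = logging.getLogger(__name__)


class CpuArchEnum(Enum):
    X86 = "x86"
    POWERPC = "powerpc"


def resolve_torch_dtype(config_dtype: torch.dtype, is_cpu: bool,
                        cpu_arch: CpuArchEnum) -> torch.dtype:
    """Mirror of the patched branch: fp16/fp32 fall back to bf16 on POWER."""
    torch_dtype = config_dtype
    if (is_cpu and cpu_arch == CpuArchEnum.POWERPC
            and config_dtype in (torch.float16, torch.float32)):
        logger.info("For POWERPC, we cast models to bfloat16 instead of "
                    "using float16 by default. Float16 is not currently "
                    "supported for POWERPC.")
        torch_dtype = torch.bfloat16
    return torch_dtype


# On a POWER10 CPU host, a config dtype of float16 or float32 now
# resolves to bfloat16; other platforms are left untouched.
assert resolve_torch_dtype(torch.float16, True,
                           CpuArchEnum.POWERPC) == torch.bfloat16
assert resolve_torch_dtype(torch.float32, True,
                           CpuArchEnum.POWERPC) == torch.bfloat16
assert resolve_torch_dtype(torch.float16, True,
                           CpuArchEnum.X86) == torch.float16

Note that the patched condition keys on config_dtype rather than the
already-resolved torch_dtype, so a float32 checkpoint that "auto" would
normally down-cast to float16 lands on bfloat16 instead of the float16 path
that is unsupported on POWERPC.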