From ed05f267e13a4057252fe22ea0a84f413381fcd9 Mon Sep 17 00:00:00 2001 From: Ahmed Mansy Date: Mon, 4 Nov 2024 18:40:33 +0200 Subject: [PATCH 1/8] [Core] Enhance memory profiling in determine_num_available_blocks (#9232) Signed-off-by: Ahmed Mansy --- vllm/worker/worker.py | 117 +++++++++++++++++++++++------------------- 1 file changed, 64 insertions(+), 53 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 8928936b4f9fc..08216a3392466 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -195,62 +195,73 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - self.model_runner.profile_run() - torch.cuda.synchronize() + try: + self.model_runner.profile_run() + torch.cuda.synchronize() - self._assert_memory_footprint_increased_during_profiling() + self._assert_memory_footprint_increased_during_profiling() - # Get the peak memory allocation recorded by torch - peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] + # Get the peak memory allocation recorded by torch + peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] - # Check for any memory left around that may have been allocated on the - # gpu outside of `torch`. NCCL operations, for example, can use a few - # GB during a forward pass - torch.cuda.empty_cache() - torch_allocated_bytes = torch.cuda.memory_stats( - )["allocated_bytes.all.current"] - total_allocated_bytes = torch.cuda.mem_get_info( - )[1] - torch.cuda.mem_get_info()[0] - non_torch_allocations = total_allocated_bytes - torch_allocated_bytes - if non_torch_allocations > 0: - peak_memory += non_torch_allocations - - available_kv_cache_memory = ( - total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) - - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. - cache_block_size = self.get_cache_block_size_bytes() - if cache_block_size == 0: - num_gpu_blocks = 0 - num_cpu_blocks = 0 - else: - num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) - num_gpu_blocks = max(num_gpu_blocks, 0) - num_cpu_blocks = max(num_cpu_blocks, 0) - - logger.info( - "Memory profiling results: total_gpu_memory=%.2fGiB" - " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB" - " memory_usage_post_profile=%.2fGib" - " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB" - " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3), - (total_gpu_memory - free_memory_pre_profile) / (1024**3), - (peak_memory - non_torch_allocations) / (1024**3), - total_allocated_bytes / (1024**3), - non_torch_allocations / (1024**3), - available_kv_cache_memory / (1024**3), - self.cache_config.gpu_memory_utilization) - - # Final cleanup - if self.model_runner.lora_manager: - self.model_runner.remove_all_loras() - gc.collect() - - return num_gpu_blocks, num_cpu_blocks + # Check for any memory left around that may have been allocated + # on thegpu outside of `torch`. 
NCCL operations, for example, + # can use a few GB during a forward pass + torch.cuda.empty_cache() + torch_allocated_bytes = torch.cuda.memory_stats( + )["allocated_bytes.all.current"] + total_allocated_bytes = torch.cuda.mem_get_info( + )[1] - torch.cuda.mem_get_info()[0] + + non_torch_allocations = (total_allocated_bytes - + torch_allocated_bytes) + if non_torch_allocations > 0: + peak_memory += non_torch_allocations + + available_kv_cache_memory = ( + total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + cache_block_size = self.get_cache_block_size_bytes() + if cache_block_size == 0: + num_gpu_blocks = 0 + num_cpu_blocks = 0 + else: + num_gpu_blocks = int(available_kv_cache_memory // + cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + + logger.info( + "Memory profiling results: total_gpu_memory=%.2fGiB" + " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB" + " memory_usage_post_profile=%.2fGib" + " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB" + " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3), + (total_gpu_memory - free_memory_pre_profile) / (1024**3), + (peak_memory - non_torch_allocations) / (1024**3), + total_allocated_bytes / (1024**3), + non_torch_allocations / (1024**3), + available_kv_cache_memory / (1024**3), + self.cache_config.gpu_memory_utilization) + + # Final cleanup + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + gc.collect() + + return num_gpu_blocks, num_cpu_blocks + except Exception: + # Provide fallback values based on total GPU memory with a + # 70% utilization safety margin + cache_block_size = self.get_cache_block_size_bytes() + safe_gpu_blocks = int((total_gpu_memory * 0.7) // + cache_block_size) if cache_block_size else 0 + return safe_gpu_blocks, 0 def _assert_memory_footprint_increased_during_profiling(self): # NOTE(woosuk): Here we assume that the other processes using the same From e2fc283ddcdd17a39850a5f59e5577fa9393e4f2 Mon Sep 17 00:00:00 2001 From: Ahmed Mansy Date: Mon, 4 Nov 2024 21:26:30 +0200 Subject: [PATCH 2/8] [Core] Refactor memory profiling based on review feedback Signed-off-by: Ahmed Mansy --- vllm/worker/worker.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 08216a3392466..9d2cf6319de36 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -264,15 +264,31 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: return safe_gpu_blocks, 0 def _assert_memory_footprint_increased_during_profiling(self): - # NOTE(woosuk): Here we assume that the other processes using the same - # GPU did not change their memory usage during the profiling. - free_gpu_memory, _ = torch.cuda.mem_get_info() - assert self.init_gpu_memory - free_gpu_memory > 0, ( + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. 
+ free_gpu_memory, total_memory = torch.cuda.mem_get_info() + memory_diff = self.init_gpu_memory - free_gpu_memory + + # If we've loaded model weights but memory shows no change, + # we're likely in a restricted environment + model_loaded = hasattr(self.model_runner, 'model') + memory_is_static = memory_diff == 0 + + is_restricted_env = model_loaded and memory_is_static + + if is_restricted_env: + logger.info( + "Detected restricted GPU environment. " + "Model is loaded but memory reports static usage. " + f"Free memory: {free_gpu_memory / (1024**3):.2f}GB, " + f"Total memory: {total_memory / (1024**3):.2f}GB" + ) + + assert memory_diff > 0 or is_restricted_env, ( "Error in memory profiling. " f"Initial free memory {self.init_gpu_memory}, current free memory" f" {free_gpu_memory}. This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Allocate GPU and CPU KV cache with the specified number of blocks. From be35fb8e5c2cc27a71f40acd19397e8a2a1fc5d8 Mon Sep 17 00:00:00 2001 From: Ahmed Mansy Date: Mon, 4 Nov 2024 21:28:32 +0200 Subject: [PATCH 3/8] Revert "[Core] Enhance memory profiling in determine_num_available_blocks (#9232)" This reverts commit 581804fe9a44d2eb5bf78ac38dc97b9530fe5779. Revert "[Core] Enhance memory profiling in determine_num_available_blocks (#9232)" This reverts commit 581804fe9a44d2eb5bf78ac38dc97b9530fe5779. Reason: The initial implementation introduced a heuristic fallback with 70% GPU memory utilization, which may not be optimal for all configurations. There is a risk of overriding user configurations silently, and the fallback may lead to sub-optimal performance or out-of-memory (OOM) errors. This revert allows us to take a more robust approach that avoids these issues and provides clearer behavior for users configuring GPU memory utilization. Signed-off-by: Ahmed Mansy --- vllm/worker/worker.py | 117 +++++++++++++++++++----------------------- 1 file changed, 53 insertions(+), 64 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9d2cf6319de36..8a4ec0f5d3562 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -195,73 +195,62 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - try: - self.model_runner.profile_run() - torch.cuda.synchronize() + self.model_runner.profile_run() + torch.cuda.synchronize() - self._assert_memory_footprint_increased_during_profiling() + self._assert_memory_footprint_increased_during_profiling() - # Get the peak memory allocation recorded by torch - peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] + # Get the peak memory allocation recorded by torch + peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] - # Check for any memory left around that may have been allocated - # on thegpu outside of `torch`. 
NCCL operations, for example, - # can use a few GB during a forward pass - torch.cuda.empty_cache() - torch_allocated_bytes = torch.cuda.memory_stats( - )["allocated_bytes.all.current"] - total_allocated_bytes = torch.cuda.mem_get_info( - )[1] - torch.cuda.mem_get_info()[0] - - non_torch_allocations = (total_allocated_bytes - - torch_allocated_bytes) - if non_torch_allocations > 0: - peak_memory += non_torch_allocations - - available_kv_cache_memory = ( - total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) - - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. - cache_block_size = self.get_cache_block_size_bytes() - if cache_block_size == 0: - num_gpu_blocks = 0 - num_cpu_blocks = 0 - else: - num_gpu_blocks = int(available_kv_cache_memory // - cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) - num_gpu_blocks = max(num_gpu_blocks, 0) - num_cpu_blocks = max(num_cpu_blocks, 0) - - logger.info( - "Memory profiling results: total_gpu_memory=%.2fGiB" - " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB" - " memory_usage_post_profile=%.2fGib" - " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB" - " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3), - (total_gpu_memory - free_memory_pre_profile) / (1024**3), - (peak_memory - non_torch_allocations) / (1024**3), - total_allocated_bytes / (1024**3), - non_torch_allocations / (1024**3), - available_kv_cache_memory / (1024**3), - self.cache_config.gpu_memory_utilization) - - # Final cleanup - if self.model_runner.lora_manager: - self.model_runner.remove_all_loras() - gc.collect() - - return num_gpu_blocks, num_cpu_blocks - except Exception: - # Provide fallback values based on total GPU memory with a - # 70% utilization safety margin - cache_block_size = self.get_cache_block_size_bytes() - safe_gpu_blocks = int((total_gpu_memory * 0.7) // - cache_block_size) if cache_block_size else 0 - return safe_gpu_blocks, 0 + # Check for any memory left around that may have been allocated on the + # gpu outside of `torch`. NCCL operations, for example, can use a few + # GB during a forward pass + torch.cuda.empty_cache() + torch_allocated_bytes = torch.cuda.memory_stats( + )["allocated_bytes.all.current"] + total_allocated_bytes = torch.cuda.mem_get_info( + )[1] - torch.cuda.mem_get_info()[0] + non_torch_allocations = total_allocated_bytes - torch_allocated_bytes + if non_torch_allocations > 0: + peak_memory += non_torch_allocations + + available_kv_cache_memory = ( + total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. 
+ cache_block_size = self.get_cache_block_size_bytes() + if cache_block_size == 0: + num_gpu_blocks = 0 + num_cpu_blocks = 0 + else: + num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + + logger.info( + "Memory profiling results: total_gpu_memory=%.2fGiB" + " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB" + " memory_usage_post_profile=%.2fGib" + " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB" + " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3), + (total_gpu_memory - free_memory_pre_profile) / (1024**3), + (peak_memory - non_torch_allocations) / (1024**3), + total_allocated_bytes / (1024**3), + non_torch_allocations / (1024**3), + available_kv_cache_memory / (1024**3), + self.cache_config.gpu_memory_utilization) + + # Final cleanup + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + gc.collect() + + return num_gpu_blocks, num_cpu_blocks def _assert_memory_footprint_increased_during_profiling(self): # NOTE(woosuk): Here we assume that the other processes using the same From 930a8eb56858d0fe68fee102689e561c21ed9344 Mon Sep 17 00:00:00 2001 From: Ahmed Mansy Date: Mon, 4 Nov 2024 21:42:25 +0200 Subject: [PATCH 4/8] [Lint] Refactor logging to remove f-string usage for compliance Signed-off-by: Ahmed Mansy --- vllm/worker/worker.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 8a4ec0f5d3562..aee9599d9932b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -253,31 +253,31 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: return num_gpu_blocks, num_cpu_blocks def _assert_memory_footprint_increased_during_profiling(self): - # NOTE(woosuk): Here we assume that the other processes using the same - # GPU did not change their memory usage during the profiling. + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. free_gpu_memory, total_memory = torch.cuda.mem_get_info() memory_diff = self.init_gpu_memory - free_gpu_memory - + # If we've loaded model weights but memory shows no change, # we're likely in a restricted environment model_loaded = hasattr(self.model_runner, 'model') memory_is_static = memory_diff == 0 - + is_restricted_env = model_loaded and memory_is_static - + if is_restricted_env: - logger.info( - "Detected restricted GPU environment. " - "Model is loaded but memory reports static usage. " - f"Free memory: {free_gpu_memory / (1024**3):.2f}GB, " - f"Total memory: {total_memory / (1024**3):.2f}GB" - ) - + logger.info("Detected restricted GPU environment. " + "Model is loaded but memory reports static usage. " + "Free memory: {:.2f}GB, Total memory: {:.2f}GB".format( + free_gpu_memory / (1024**3), + total_memory / (1024**3))) + assert memory_diff > 0 or is_restricted_env, ( - "Error in memory profiling. " + "Error in memory profiling." f"Initial free memory {self.init_gpu_memory}, current free memory" f" {free_gpu_memory}. This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Allocate GPU and CPU KV cache with the specified number of blocks. 
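Note for reviewers: the block arithmetic that PATCH 1 wraps in a try/except and PATCH 3 restores can be sanity-checked with a small standalone sketch. This is illustrative Python only — the helper name estimate_num_blocks and every numeric value below are made up for the example, not taken from the worker code in the diffs.

from typing import Tuple

GiB = 1024 ** 3

def estimate_num_blocks(total_gpu_memory: int,
                        peak_memory: int,
                        gpu_memory_utilization: float,
                        cache_block_size: int,
                        swap_space_bytes: int) -> Tuple[int, int]:
    # KV-cache budget = allowed share of the GPU minus the profiled peak usage.
    available_kv_cache_memory = (
        total_gpu_memory * gpu_memory_utilization - peak_memory)
    if cache_block_size == 0:
        return 0, 0
    num_gpu_blocks = max(int(available_kv_cache_memory // cache_block_size), 0)
    num_cpu_blocks = max(int(swap_space_bytes // cache_block_size), 0)
    return num_gpu_blocks, num_cpu_blocks

# Hypothetical 80 GiB GPU, 30 GiB profiled peak, 0.9 utilization,
# 2 MiB cache blocks and 4 GiB of CPU swap space.
print(estimate_num_blocks(80 * GiB, 30 * GiB, 0.9, 2 * 1024 ** 2, 4 * GiB))
# -> (21504, 2048)

If the profiled peak exceeds the utilization budget, available_kv_cache_memory goes negative and the clamp yields zero GPU blocks — the same guard the diffs keep applying with max(num_gpu_blocks, 0).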
From 18f25a3831e09ce1808514ecce79d87b29e3472a Mon Sep 17 00:00:00 2001 From: Ahmed Mansy Date: Mon, 4 Nov 2024 21:44:19 +0200 Subject: [PATCH 5/8] [Lint] Adjust logging to use % formatting for compliance Signed-off-by: Ahmed Mansy --- vllm/worker/worker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index aee9599d9932b..2827d0468943e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -267,10 +267,10 @@ def _assert_memory_footprint_increased_during_profiling(self): if is_restricted_env: logger.info("Detected restricted GPU environment. " - "Model is loaded but memory reports static usage. " - "Free memory: {:.2f}GB, Total memory: {:.2f}GB".format( - free_gpu_memory / (1024**3), - total_memory / (1024**3))) + "Model is loaded but memory reports static usage. " + "Free memory: %.2fGB, Total memory: %.2fGB", + free_gpu_memory / (1024**3), + total_memory / (1024**3)) assert memory_diff > 0 or is_restricted_env, ( "Error in memory profiling." From 86c41e9f5aa2a46ed51a21af7fb51c31fb0cd230 Mon Sep 17 00:00:00 2001 From: Ahmed Mansy Date: Tue, 5 Nov 2024 00:16:53 +0200 Subject: [PATCH 6/8] [Core] Revert previous implementation and update worker to check for GPU blocks override Signed-off-by: Ahmed Mansy --- vllm/worker/worker.py | 56 +++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 2827d0468943e..d9b8ad8ab5df8 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -179,33 +179,45 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: KV blocks may be allocated without OOMs. The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks + Then, it calculates the maximum possible number of GPU and CPU blocks that can be allocated with the remaining free memory. .. tip:: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. """ - # Profile the memory usage of the model and get the maximum number of - # cache blocks that can be allocated with the remaining free memory. + # Check if GPU blocks override is provided; skip profiling if so. + num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override + + if num_gpu_blocks_override is not None: + # Calculate num_cpu_blocks based on available swap space. + cache_block_size = self.get_cache_block_size_bytes() + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) if cache_block_size else 0 + num_cpu_blocks = max(num_cpu_blocks, 0) + + logger.info( + "Using num_gpu_blocks_override=%d, calculated num_cpu_blocks=%d", + num_gpu_blocks_override, num_cpu_blocks) + + return num_gpu_blocks_override, num_cpu_blocks + + # Proceed with full profiling if no override is provided. torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. + # Profile model memory usage with a forward pass. self.model_runner.profile_run() torch.cuda.synchronize() self._assert_memory_footprint_increased_during_profiling() - # Get the peak memory allocation recorded by torch + # Gather peak memory allocation. 
peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] - # Check for any memory left around that may have been allocated on the - # gpu outside of `torch`. NCCL operations, for example, can use a few - # GB during a forward pass + # Adjust for non-torch memory allocations. torch.cuda.empty_cache() torch_allocated_bytes = torch.cuda.memory_stats( )["allocated_bytes.all.current"] @@ -219,33 +231,31 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. + # Calculate the number of blocks for both GPU and CPU. cache_block_size = self.get_cache_block_size_bytes() - if cache_block_size == 0: - num_gpu_blocks = 0 - num_cpu_blocks = 0 - else: - num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) - num_gpu_blocks = max(num_gpu_blocks, 0) - num_cpu_blocks = max(num_cpu_blocks, 0) + num_gpu_blocks = int(available_kv_cache_memory // + cache_block_size) if cache_block_size else 0 + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) if cache_block_size else 0 + num_gpu_blocks, num_cpu_blocks = max(num_gpu_blocks, + 0), max(num_cpu_blocks, 0) logger.info( "Memory profiling results: total_gpu_memory=%.2fGiB" " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB" " memory_usage_post_profile=%.2fGib" " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB" - " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3), + " gpu_memory_utilization=%.2f", + total_gpu_memory / (1024**3), (total_gpu_memory - free_memory_pre_profile) / (1024**3), (peak_memory - non_torch_allocations) / (1024**3), total_allocated_bytes / (1024**3), non_torch_allocations / (1024**3), available_kv_cache_memory / (1024**3), - self.cache_config.gpu_memory_utilization) + self.cache_config.gpu_memory_utilization, + ) - # Final cleanup + # Final cleanup. if self.model_runner.lora_manager: self.model_runner.remove_all_loras() gc.collect() From c1dea4930fa9c4a9a0b03006e60c51e38912ff1e Mon Sep 17 00:00:00 2001 From: Ahmed Mansy Date: Tue, 5 Nov 2024 00:18:11 +0200 Subject: [PATCH 7/8] [Core] Revert previous implementation and update worker to check for GPU blocks override Signed-off-by: Ahmed Mansy --- vllm/worker/worker.py | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d9b8ad8ab5df8..dd75daaebd5b4 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -265,29 +265,12 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: def _assert_memory_footprint_increased_during_profiling(self): # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. - free_gpu_memory, total_memory = torch.cuda.mem_get_info() - memory_diff = self.init_gpu_memory - free_gpu_memory - - # If we've loaded model weights but memory shows no change, - # we're likely in a restricted environment - model_loaded = hasattr(self.model_runner, 'model') - memory_is_static = memory_diff == 0 - - is_restricted_env = model_loaded and memory_is_static - - if is_restricted_env: - logger.info("Detected restricted GPU environment. " - "Model is loaded but memory reports static usage. 
" - "Free memory: %.2fGB, Total memory: %.2fGB", - free_gpu_memory / (1024**3), - total_memory / (1024**3)) - - assert memory_diff > 0 or is_restricted_env, ( - "Error in memory profiling." + free_gpu_memory, _ = torch.cuda.mem_get_info() + assert self.init_gpu_memory - free_gpu_memory > 0, ( + "Error in memory profiling. " f"Initial free memory {self.init_gpu_memory}, current free memory" f" {free_gpu_memory}. This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Allocate GPU and CPU KV cache with the specified number of blocks. From 1d976f26d520da0868a450edc0e703f5752b6425 Mon Sep 17 00:00:00 2001 From: Ahmed Mansy Date: Tue, 5 Nov 2024 00:26:09 +0200 Subject: [PATCH 8/8] [Lint] Adjust logging to be less than 80 Signed-off-by: Ahmed Mansy --- vllm/worker/worker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index dd75daaebd5b4..effd3bf7a193e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -197,7 +197,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_cpu_blocks = max(num_cpu_blocks, 0) logger.info( - "Using num_gpu_blocks_override=%d, calculated num_cpu_blocks=%d", + "Using num_gpu_blocks_override=%d,calculated num_cpu_blocks=%d", num_gpu_blocks_override, num_cpu_blocks) return num_gpu_blocks_override, num_cpu_blocks @@ -271,6 +271,7 @@ def _assert_memory_footprint_increased_during_profiling(self): f"Initial free memory {self.init_gpu_memory}, current free memory" f" {free_gpu_memory}. This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: """Allocate GPU and CPU KV cache with the specified number of blocks.