From 20eafe9fae7d55ecf5a4802b1c4480158e18f60f Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Fri, 28 Jun 2024 11:15:03 +0200
Subject: [PATCH] Add more metrics to high level profiler (#63)

* Add more detailed event names to profiler
* Add more profiler stats
* separate prompt and decode batch utilization
* Add more metrics
* revert engine/metrics.py changes
* un-singletonify (what a funny word) habana profiler
* formatting
* add batch block utilization metric
* fix division by zero
* fix batch_block_utilization formula
* minor refactors
---
 vllm/worker/habana_model_runner.py | 82 ++++++++++++++++++++++++++----
 1 file changed, 72 insertions(+), 10 deletions(-)

diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 6a9cb6f066ea1..1a9206a314d5c 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -241,6 +241,7 @@ def __init__(
         self.scheduler_config = scheduler_config
         self.lora_config = lora_config
         self.load_config = load_config
+        self.cache_config = cache_config
         self.is_driver_worker = is_driver_worker
         self.profiler = Profiler()
@@ -267,6 +268,9 @@ def __init__(
         self.lora_manager: LRUCacheWorkerLoRAManager = None
         self.model: torch.nn.Module = None
 
+        # Profiler stats
+        self.profiler_counter_helper = HabanaProfilerCounterHelper()
+
         self._setup_buckets()
 
     def load_model(self) -> None:
@@ -876,19 +880,18 @@ def execute_model(
         output.outputs = output.outputs[:real_batch_size]
         htorch.core.mark_step()
 
-        if self.is_driver_worker:
+        if self.is_driver_worker and self.profiler.enabled:
             # Stop recording 'execute_model' event
             self.profiler.end()
             event_end = self.profiler.get_timestamp_us()
-            duration = event_end - event_start
-            throughput = batch_size_padded / (duration / 1e6)
-            throughput_effective = real_batch_size / (duration / 1e6)
-            counters = {
-                'batch_size': batch_size_padded,
-                'batch_size_effective': real_batch_size,
-                'throughput': throughput,
-                'throughput_effective': throughput_effective
-            }
+            counters = self.profiler_counter_helper.get_counter_dict(
+                cache_config=self.cache_config,
+                duration=event_end-event_start,
+                seq_len=seq_len,
+                batch_size_padded=batch_size_padded,
+                real_batch_size=real_batch_size,
+                seq_group_metadata_list=seq_group_metadata_list,
+                is_prompt=is_prompt)
             self.profiler.record_counter(event_start, counters)
 
         return output
@@ -1014,3 +1017,62 @@ def vocab_size(self) -> int:
 
 def _maybe_wrap_in_hpu_graph(model):
     return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(model)) if htorch.utils.internal.is_lazy() else HpuModelAdapter(model)
+
+
+class HabanaProfilerCounterHelper():
+    def __init__(self):
+        self.niter = 0
+        self.average_real_throughput = None
+        self.logged_once = False
+
+    def get_counter_dict(self, cache_config, duration, seq_len, batch_size_padded, real_batch_size, seq_group_metadata_list, is_prompt):
+        throughput = batch_size_padded / (duration / 1e6)
+        throughput_effective = real_batch_size / (duration / 1e6)
+        real_seq_lens = [len(seq_data.prompt_token_ids) + len(seq_data.output_token_ids) for seq_group_metadata in seq_group_metadata_list for seq_data in seq_group_metadata.seq_data.values()]
+        real_max_seq_len = max(real_seq_lens)
+        real_num_tokens = sum(real_seq_lens)
+        padded_num_tokens = batch_size_padded * seq_len
+        batch_token_utilization = real_num_tokens / padded_num_tokens
+        if self.average_real_throughput is None:
+            self.average_real_throughput = throughput_effective
+        else: # https://www.heikohoffmann.de/htmlthesis/node134.html
+            self.average_real_throughput = self.average_real_throughput + 1/(self.niter+1) * (throughput_effective-self.average_real_throughput)
+        phase = "prompt" if is_prompt else "decode"
+        counters = {
+            f'{phase}_bucket_batch_size': batch_size_padded,
+            f'{phase}_batch_size': real_batch_size,
+            f'{phase}_bucket_seq_len': seq_len,
+            f'{phase}_seq_len': real_max_seq_len,
+            f'{phase}_bucket_gen_throughput': throughput,
+            f'{phase}_real_gen_throughput': throughput_effective,
+            f'{phase}_batch_token_utilization': batch_token_utilization,
+            'average_real_throughput': self.average_real_throughput,
+            'engine_iteration': self.niter,
+        }
+        self.niter += 1
+        if is_prompt:
+            prompt_seq_lens = [len(seq_data.prompt_token_ids) for seq_group_metadata in seq_group_metadata_list for seq_data in seq_group_metadata.seq_data.values()]
+            prompt_bucket_in_throughput = (seq_len*batch_size_padded) / (duration / 1e6)
+            prompt_real_in_throughput = sum(prompt_seq_lens) / (duration / 1e6)
+            counters[f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput
+            counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput
+
+        # KV cache might not be created yet (e.g. for profiling run)
+        if cache_config.num_gpu_blocks is not None and cache_config.num_gpu_blocks != 0:
+            cache_num_blocks_used = [math.ceil(sl/cache_config.block_size) for sl in real_seq_lens]
+            cache_total_num_blocks_used = sum(cache_num_blocks_used)
+            num_cache_blocks = cache_config.num_gpu_blocks
+            cache_total_num_free_blocks = num_cache_blocks - cache_total_num_blocks_used
+            cache_computed_utilization = cache_total_num_blocks_used / num_cache_blocks
+            max_blocks_per_seq = math.ceil(seq_len/cache_config.block_size)
+            batch_block_utilization = cache_total_num_blocks_used / (batch_size_padded * max_blocks_per_seq)
+            counters['cache_num_blocks_used'] = cache_total_num_blocks_used
+            counters['cache_num_free_blocks'] = cache_total_num_free_blocks
+            counters['cache_computed_utilization'] = cache_computed_utilization
+            counters[f'{phase}_batch_block_utilization'] = batch_block_utilization
+        if not self.logged_once:
+            counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks
+            counters['const_gpu_memory_utilization'] = cache_config.gpu_memory_utilization
+            counters['const_block_size'] = cache_config.block_size
+            self.logged_once = True
+        return counters
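
The running average referenced by the heikohoffmann.de link in get_counter_dict is the standard incremental mean, mu_new = mu + (x - mu) / (n + 1), which avoids keeping a history of per-iteration throughput samples. A minimal standalone sketch of the same update rule; the function and variable names here are illustrative and not part of the patch:

    # Incremental mean: given the mean of `count` samples, fold in one more
    # sample to obtain the mean of count + 1 samples without storing history.
    def update_running_mean(mean, count, sample):
        if mean is None:  # nothing averaged yet, mirror the `is None` branch above
            return sample
        return mean + (sample - mean) / (count + 1)

    samples = [10.0, 20.0, 30.0, 40.0]
    mean = None
    for niter, x in enumerate(samples):
        mean = update_running_mean(mean, niter, x)
    print(mean)  # 25.0, identical to sum(samples) / len(samples)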
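
Both utilization counters are ratios of real work to the padded bucket's capacity: batch_token_utilization divides the real token count by batch_size_padded * seq_len, and batch_block_utilization divides the KV-cache blocks actually used by the most blocks the padded bucket could address. A worked toy example under assumed values (block size, bucket dimensions and sequence lengths are made up for illustration, not taken from a real run):

    import math

    block_size = 128                        # assumed cache_config.block_size
    seq_len = 1024                          # padded bucket sequence length
    batch_size_padded = 8                   # padded bucket batch size
    real_seq_lens = [900, 750, 512, 1000]   # four real sequences, rest is padding

    # Real tokens vs. tokens the padded bucket accounts for.
    batch_token_utilization = sum(real_seq_lens) / (batch_size_padded * seq_len)

    # KV-cache blocks actually used vs. the maximum the padded bucket could use.
    blocks_used = sum(math.ceil(sl / block_size) for sl in real_seq_lens)
    max_blocks_per_seq = math.ceil(seq_len / block_size)
    batch_block_utilization = blocks_used / (batch_size_padded * max_blocks_per_seq)

    print(round(batch_token_utilization, 3))  # 0.386
    print(round(batch_block_utilization, 3))  # 0.406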