Skip to content

Commit

Permalink
Improve monitoring metrics (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeffwan authored Oct 25, 2024
1 parent a8ae12c commit bff7105
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
1 change: 1 addition & 0 deletions benchmarks/benchmark_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from backend_request_func import get_tokenizer

PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
PROMPT = "You are an AI programming assistant and your task is to generate a SQL query based on the input database schema and user questions.\n### Task Description:\nGiven the following database schema, please write a SQL query to answer the given question.\n\n### Schema:\nThe database contains 3 tables: ['customer', 'branch', 'customer_order'].\n\n- **Table**: customer\n\t- *Description*: The table customer has 5 columns: ['Customer_ID', 'Name', 'Nationality', 'Card_Credit', 'Level_of_Membership'].\n\t- *Primary Key*: Customer_ID\n\t- *Foreign Keys*: \n\t- *Column*: Customer_ID\n\t\t- Type: INT\n\t\t- Sampled Values: 1, 2, 3, 4, 5\n\t- *Column*: Name\n\t\t- Type: TEXT\n\t\t- Sampled Values: Arthur Morris, Bill Edrich, Cyril Washbrook, Denis Compton, Donald Bradman\n\t- *Column*: Nationality\n\t\t- Type: TEXT\n\t\t- Sampled Values: Australia, England\n\t- *Column*: Card_Credit\n\t\t- Type: REAL\n\t\t- Sampled Values: 31.9, 42.75, 44.28, 50.85, 62.44\n\t- *Column*: Level_of_Membership\n\t\t- Type: INT\n\t\t- Sampled Values: 0, 1, 2, 3\n\n- **Table**: branch\n\t- *Description*: The table branch has 4 columns: ['Branch_ID', 'Manager', 'Years_opened', 'Location_of_office'].\n\t- *Primary Key*: Branch_ID\n\t- *Foreign Keys*: \n\t- *Column*: Branch_ID\n\t\t- Type: INT\n\t\t- Sampled Values: 1, 2, 3, 4, 5\n\t- *Column*: Manager\n\t\t- Type: TEXT\n\t\t- Sampled Values: Ashby Lazale, Breton Robert, Campbell Jessie, Cobb Sedrick, Hayes Steven\n\t- *Column*: Years_opened\n\t\t- Type: INT\n\t\t- Sampled Values: 2, 3, 4, 5, 6\n\t- *Column*: Location_of_office\n\t\t- Type: TEXT\n\t\t- Sampled Values: Bridgeport, Cheshire, Hartford, Waterbury\n\n- **Table**: customer_order\n\t- *Description*: The table customer_order has 4 columns: ['Customer_ID', 'Branch_ID', 'Dish_Name', 'Quantity'].\n\t- *Primary Key*: ['Customer_ID', 'Branch_ID', 'Dish_Name']\n\t- *Foreign Keys*: customer_order.Customer_ID = customer.Customer_ID, customer_order.Branch_ID = branch.Branch_ID\n\t- *Column*: 
Customer_ID\n\t\t- Type: INT\n\t\t- Sampled Values: 1, 2, 3, 4, 5\n\t- *Column*: Branch_ID\n\t\t- Type: INT\n\t\t- Sampled Values: 6, 9, 10\n\t- *Column*: Dish_Name\n\t\t- Type: TEXT\n\t\t- Sampled Values: Chow Mein, Kung Pao Chicken, Ma Po Tofu, Peking Roasted Duck, Spring Rolls\n\t- *Column*: Quantity\n\t\t- Type: INT\n\t\t- Sampled Values: 1, 2, 4\n\n### Requirements:\n* Please first return the SQL query to answer the question and then explain your SQL query step by step.\n* Please generate your response using the following format:\n```sql\n<YOUR SQL QUERY>\n```\n<YOUR EXPLANATION>, where the SQL query is in a Markdown code block.\n* Provide a detailed explanation that reflects your reasoning process step by step. Specifically, explain each part of the SQL query (each clause, operator, etc.) and how they work together to answer the question step by step.\n* Please organize the explanation of each SQL step using a Markdown list, following this format:\n- <what is done in the 1st step>\n`<SQL clause used in this step>`\n- <what is done in the 2nd step>\n`<SQL clause used in this step>`\n- <what is done in the 3rd step>\n`<SQL clause used in this step>`...\n* If a certain step involves a nested subquery, provide a detailed explanation for each part of the subquery. You can explain the subquery using the following format:\n- <what is done in the sub-SQL>\n`<sub-SQL>`\n\t* <what is done in the 1st step in sub-SQL>\n\t`<SQL clause used in this step>`\n\t* <what is done in the 2nd step in sub-SQL>\n\t`<SQL clause used in this step>`...\n* When quoting parts of the SQL query in your explanation, please enclose the statement in single backticks, like this: `<part of SQL>`.\n* Please keep your explanation concise and clear within 100 words.\n* Please do NOT select extra columns that are not explicitly requested in the query.\n* Ensure that the table and column names in the generated query exactly match those in the schema. 
Do NOT include any columns or tables that are not present in the provided schema.\n* Please ensure that the SQL query remains concise and avoids unnecessary joins with unrelated tables.\n\n### Question:\nShow the most common nationality of customers.\n\n### Output:\n"


def test_prefix(llm=None, sampling_params=None, prompts=None):
Expand Down
2 changes: 1 addition & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,7 +1500,7 @@ def prepare_model_input(
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
kv_caches: List[torch.Tensor] = [],
) -> ModelInputForGPUWithSamplingMetadata:
) -> Tuple[ModelInputForGPUWithSamplingMetadata, Optional[Dict[str, int]]]:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
Expand Down
14 changes: 8 additions & 6 deletions vllm/worker/vineyard_llm_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,13 @@ def prefetch_seq_kv_caches(
query_tokens
) = query_args

start_time = time.time()
start_time = time.perf_counter()
matched = self.cache.query(
prefix=query_prefix,
tokens=query_tokens,
kv_cache_list=self.tensors[:query_token_size],
)
duration = time.time() - start_time
duration = time.perf_counter() - start_time
self.metrics.time_query.append(duration)
self.metrics.normalized_time_query.append(duration/len(tokens))
# synchronized across tensor parallel ranks
Expand All @@ -198,6 +198,8 @@ def prefetch_seq_kv_caches(
offset = context_len % self.chunk_size
matched -= offset

# Cap `matched` at token_chunk_size - 1 so at least one token's KV is
# recomputed by the model forward pass.
# TODO: revisit later — we are looking for a way to fully avoid this recomputation.
matched = min(matched, token_chunk_size - 1)
if matched <= 0:
return seq_id, 0
Expand Down Expand Up @@ -228,7 +230,7 @@ def prefetch_seq_kv_caches(
copy_start.record()
buffer = self.buffer[:, :, offset:offset+matched].cuda()
copy_end.record()
torch.cuda.synchronize()
copy_end.synchronize()
duration = copy_start.elapsed_time(copy_end) / 1000.0
self.metrics.time_load.append(duration)
self.metrics.normalized_time_load.append(0 if matched == 0 else duration/matched)
Expand Down Expand Up @@ -370,20 +372,20 @@ def update_seq_kv_caches(
for j in range(self.layer):
self.buffer[:, j, :update_token_size].copy_(
kv_caches[j][:, slot_mapping // block_size, slot_mapping % block_size])
torch.cuda.synchronize()
end_unload.record()
end_unload.synchronize()
duration = start_unload.elapsed_time(end_unload) / 1000.0
self.metrics.time_unload.append(duration)
self.metrics.normalized_time_unload.append(0 if update_token_size == 0 else duration/update_token_size)

start_time = time.time()
start_time = time.perf_counter()
# updates into vineyard
updated = self.cache.update(
prefix=update_prefix,
tokens=update_tokens,
kv_cache_list=self.tensors[:update_token_size],
)
duration = time.time() - start_time
duration = time.perf_counter() - start_time
self.metrics.time_update.append(duration)
self.metrics.normalized_time_update.append(0 if update_token_size == 0 else duration/update_token_size)

Expand Down
1 change: 1 addition & 0 deletions vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def execute_model(
)

model_execute_time = time.perf_counter() - start_time
# TODO: make update_kv_caches async
if self.model_runner.vineyard_llm_cache and self.kv_cache[worker_input.virtual_engine][0] is not None:
self.model_runner.vineyard_llm_cache.update_kv_caches(
cache_hints,
Expand Down

0 comments on commit bff7105

Please sign in to comment.