Added back multiprocessing to benchmark scripts (#1246)

- [x] Refactoring
yuzhichang authored May 25, 2024
1 parent aef0348 commit 0f9b55f
Showing 3 changed files with 196 additions and 55 deletions.
31 changes: 16 additions & 15 deletions docs/references/benchmark.md
@@ -66,34 +66,35 @@ docker run -d --name infinity -v $HOME/infinity:/var/infinity --ulimit nofile=50

4. Run Benchmark:

Drop file cache before benchmark query latency.
Drop file cache before benchmark.

```bash
echo 3 | sudo tee /proc/sys/vm/drop_caches
```

Tasks of the Python script `run.py` include:
- Delete the original data.
- Re-insert the data.
- Calculate the time to insert data and build index.
- Calculate query latency.
- Calculate QPS.
- Generate fulltext query set.
- Measure the time to import data and build index.
- Measure the query latency.
- Measure the QPS.

```bash
$ python run.py -h
usage: run.py [-h] [--generate] [--import] [--query] [--query-express QUERY_EXPRESS] [--engine ENGINE] [--dataset DATASET]
usage: run.py [-h] [--generate] [--import] [--query QUERY] [--query-express QUERY_EXPRESS] [--concurrency CONCURRENCY] [--engine ENGINE] [--dataset DATASET]

RAG Database Benchmark

options:
-h, --help show this help message and exit
--generate Generate fulltext queries based on the dataset
--import Import data set into database engine
--query Run single client to benchmark query latency
--query-express QUERY_EXPRESS
Run multiple clients in express mode to benchmark QPS
--engine ENGINE database engine to benchmark, one of: all, infinity, qdrant, elasticsearch
--dataset DATASET data set to benchmark, one of: all, gist, sift, geonames, enwiki
-h, --help show this help message and exit
--generate Generate the fulltext query set based on the dataset (default: False)
--import Import the dataset into the database engine (default: False)
--query QUERY Run the query set once with the given number of clients, recording results and latencies. This is for result validation and latency analysis (default: 0)
--query-express QUERY_EXPRESS
Run the query set in random order with the given number of clients, without recording results or latencies. This is for QPS measurement (default: 0)
--concurrency CONCURRENCY
Choose the concurrency mechanism, one of: mp - multiprocessing (recommended), mt - multithreading (default: mp)
--engine ENGINE Choose the database engine to benchmark, one of: infinity, qdrant, elasticsearch (default: infinity)
--dataset DATASET Choose the dataset to benchmark, one of: gist, sift, geonames, enwiki (default: enwiki)
```

The following commands are for engine `infinity` and dataset `enwiki`:
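The exact command list is collapsed in this view; a plausible sequence, assuming only the flags shown in the usage text above, would be:

```bash
# Generate the fulltext query set for enwiki
python run.py --generate --dataset enwiki

# Import enwiki into infinity and build the index
python run.py --import --engine infinity --dataset enwiki

# Run the query set once with a single client to validate results and record latencies
python run.py --query 1 --engine infinity --dataset enwiki

# Measure QPS with 8 client processes (multiprocessing is the default mechanism)
python run.py --query-express 8 --concurrency mp --engine infinity --dataset enwiki
```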
202 changes: 167 additions & 35 deletions python/benchmark/clients/base_client.py
@@ -9,6 +9,7 @@
import h5py
import numpy as np
import threading
import multiprocessing


class BaseClient:
@@ -25,11 +26,19 @@ def __init__(self, conf_path: str) -> None:
self.data = None
self.queries = list()
self.clients = list()
self.lock = threading.Lock()
self.next_begin = 0
self.results = []
self.done_queries = 0
self.active_threads = 0
# Following are for multithreading
self.mt_lock = threading.Lock()
self.mt_next_begin = 0
self.mt_done_queries = 0
self.mt_active_workers = 0
self.mt_results = []
# Following are for multiprocessing
self.mp_manager = multiprocessing.Manager()
self.mp_lock = multiprocessing.Lock()
self.mp_next_begin = multiprocessing.Value("i", 0, lock=False)
self.mp_done_queries = multiprocessing.Value("i", 0, lock=False)
self.mp_active_workers = multiprocessing.Value("i", 0, lock=False)
self.mp_results = self.mp_manager.list()

@abstractmethod
def upload(self):
@@ -39,7 +48,7 @@ def upload(self):
pass

@abstractmethod
def setup_clients(self, num_threads=1):
def setup_clients(self, num_workers=1):
pass

@abstractmethod
@@ -60,8 +69,8 @@ def download_data(self, url, target_path):
else:
subprocess.run(["wget", "-O", target_path, url], check=True)

def search(self, is_express=False, num_threads=1):
self.setup_clients(num_threads)
def search_mt(self, is_express=False, num_workers=1):
self.setup_clients(num_workers)

query_path = os.path.join(self.path_prefix, self.data["query_path"])
_, ext = os.path.splitext(query_path)
@@ -81,17 +90,17 @@ query = json.loads(line)["vector"]
query = json.loads(line)["vector"]
self.queries.append(query)

self.active_threads = num_threads
self.mt_active_workers = num_workers
threads = []
for i in range(num_threads):
for i in range(num_workers):
threads.append(
threading.Thread(
target=self.search_thread_mainloop,
args=[is_express, i],
daemon=True,
)
)
for i in range(num_threads):
for i in range(num_workers):
threads[i].start()

report_qps_sec = 60
@@ -107,28 +116,28 @@
done_queries_prev = 0
done_queries_curr = 0

while self.active_threads > 0:
while self.mt_active_workers > 0:
time.sleep(sleep_sec)
sleep_cnt += 1
if sleep_cnt < report_qps_sec / sleep_sec:
continue
sleep_cnt = 0
now = time.time()
if done_warm_up:
with self.lock:
done_queries_curr = self.done_queries
with self.mt_lock:
done_queries_curr = self.mt_done_queries
avg_start = done_queries_curr / (now - start)
avg_interval = (done_queries_curr - done_queries_prev) / (
now - report_prev
)
done_queries_prev = done_queries_curr
report_prev = now
logging.info(
f"average QPS since {start_str}: {avg_start}, average QPS of last interval:{avg_interval}"
f"average QPS since {start_str}: {avg_start}, average QPS of last interval: {avg_interval}"
)
else:
with self.lock:
self.done_queries = 0
with self.mt_lock:
self.mt_done_queries = 0
start = now
start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
report_prev = now
@@ -137,10 +146,10 @@ def search(self, is_express=False, num_threads=1):
"Collecting statistics for 30 minutes. Print statistics so far every minute. Type Ctrl+C to quit."
)

for i in range(num_threads):
for i in range(num_workers):
threads[i].join()
if not is_express:
self.save_and_check_results(self.results)
self.save_and_check_results(self.mt_results)

def search_thread_mainloop(self, is_express: bool, client_id: int):
query_batch = 100
@@ -152,37 +161,159 @@ def search_thread_mainloop(self, is_express: bool, client_id: int):
for i in range(query_batch):
query_id = local_rng.randrange(0, num_queries)
_ = self.do_single_query(query_id, client_id)
with self.lock:
self.done_queries += query_batch
with self.mt_lock:
self.mt_done_queries += query_batch
else:
begin = 0
end = 0
local_results = list()
while end < num_queries:
with self.lock:
self.done_queries += end - begin
begin = self.next_begin
with self.mt_lock:
self.mt_done_queries += end - begin
begin = self.mt_next_begin
end = begin + query_batch
if end > num_queries:
end = num_queries
self.next_begin = end
self.mt_next_begin = end
for query_id in range(begin, end):
start = time.time()
result = self.do_single_query(query_id, client_id)
latency = (time.time() - start) * 1000
result = [(query_id, latency)] + result
local_results.append(result)
with self.lock:
self.done_queries += end - begin
self.results += local_results
with self.lock:
self.active_threads -= 1
with self.mt_lock:
self.mt_done_queries += end - begin
self.mt_results += local_results
with self.mt_lock:
self.mt_active_workers -= 1

def search_mp(self, is_express=False, num_workers=1):
query_path = os.path.join(self.path_prefix, self.data["query_path"])
_, ext = os.path.splitext(query_path)
if self.data["mode"] == "fulltext":
assert ext == ".txt"
for line in open(query_path, "r"):
line = line.strip()
self.queries.append(line)
else:
self.data["mode"] == "vector"
if ext == ".hdf5":
with h5py.File(query_path, "r") as f:
self.queries = list(f["test"])
else:
assert ext == "jsonl"
for line in open(query_path, "r"):
query = json.loads(line)["vector"]
self.queries.append(query)

self.mp_active_workers.value = num_workers
workers = []
for i in range(num_workers):
workers.append(
multiprocessing.Process(
target=self.search_process_mainloop,
args=[is_express],
daemon=True,
)
)
for i in range(num_workers):
workers[i].start()

report_qps_sec = 60
sleep_sec = 10
sleep_cnt = 0
done_warm_up = True
if is_express:
logging.info(f"Let database warm-up for {report_qps_sec} seconds")
done_warm_up = False
start = time.time()
start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
report_prev = start
done_queries_prev = 0
done_queries_curr = 0

while True:
active_workers = 0
with self.mp_lock:
active_workers = self.mp_active_workers.value
if active_workers <= 0:
break
time.sleep(sleep_sec)
sleep_cnt += 1
if sleep_cnt < report_qps_sec / sleep_sec:
continue
sleep_cnt = 0
now = time.time()
if done_warm_up:
with self.mp_lock:
done_queries_curr = self.mp_done_queries.value
avg_start = done_queries_curr / (now - start)
avg_interval = (done_queries_curr - done_queries_prev) / (
now - report_prev
)
done_queries_prev = done_queries_curr
report_prev = now
logging.info(
f"average QPS since {start_str}: {avg_start}, average QPS of last interval: {avg_interval}"
)
else:
with self.mp_lock:
self.mp_done_queries.value = 0
start = now
start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
report_prev = now
done_warm_up = True
logging.info(
"Collecting statistics for 30 minutes. Print statistics so far every minute. Type Ctrl+C to quit."
)

for i in range(num_workers):
workers[i].join()
if not is_express:
self.save_and_check_results(self.mp_results)

def search_process_mainloop(self, is_express: bool):
self.setup_clients(1) # socket is unsafe to share among workers
query_batch = 100
num_queries = len(self.queries)
if is_express:
local_rng = random.Random() # random number generator per process
deadline = time.time() + 30 * 60 # 30 minutes
while time.time() < deadline:
for i in range(query_batch):
query_id = local_rng.randrange(0, num_queries)
_ = self.do_single_query(query_id, 0)
with self.mp_lock:
self.mp_done_queries.value += query_batch
else:
begin = 0
end = 0
local_results = list()
while end < num_queries:
with self.mp_lock:
self.mp_done_queries.value += end - begin
begin = self.mp_next_begin.value
end = begin + query_batch
if end > num_queries:
end = num_queries
self.mp_next_begin.value = end
for query_id in range(begin, end):
start = time.time()
result = self.do_single_query(query_id, 0)
latency = (time.time() - start) * 1000
result = [(query_id, latency)] + result
local_results.append(result)
with self.mp_lock:
self.mp_done_queries.value += end - begin
self.mp_results += local_results
with self.mp_lock:
self.mp_active_workers.value -= 1

def save_and_check_results(self, results: list[list[Any]]):
"""
Compare the search results with ground truth to calculate recall.
"""
self.results.sort(key=lambda x: x[0][0])
results = sorted(results, key=lambda x: x[0][0])
if "result_path" in self.data:
result_path = self.data["result_path"]
with open(result_path, "w") as f:
@@ -242,7 +373,8 @@ def run_experiment(self, args):
self.upload()
finish_time = time.time()
logging.info(f"upload finish, cost time = {finish_time - start_time}")
elif args.query >= 1:
self.search(is_express=False, num_threads=args.query)
elif args.query_express >= 1:
self.search(is_express=True, num_threads=args.query_express)
elif args.query >= 1 or args.query_express >= 1:
is_express = True if args.query_express >= 1 else False
search_func = self.search_mp if args.concurrency == "mp" else self.search_mt
num_workers = max(args.query, args.query_express)
search_func(is_express, num_workers)
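The new `search_mp` path keeps its shared counters in `multiprocessing.Value` objects created with `lock=False` and funnels every read-modify-write through a single explicit `multiprocessing.Lock`, while per-query results are collected in a `multiprocessing.Manager().list()`. A minimal standalone sketch of that pattern (the names and the squared-numbers workload are illustrative, not from this commit):

```python
import multiprocessing


def worker(lock, done, results, n):
    # Stand-in for running a batch of n queries.
    local = [i * i for i in range(n)]
    # Every read-modify-write goes through the one explicit lock,
    # which is why the Value can safely be created with lock=False.
    with lock:
        done.value += n
        results += local  # a Manager ListProxy supports += (extend)


if __name__ == "__main__":
    manager = multiprocessing.Manager()
    lock = multiprocessing.Lock()
    done = multiprocessing.Value("i", 0, lock=False)  # plain shared int, no built-in lock
    results = manager.list()

    workers = [
        multiprocessing.Process(target=worker, args=(lock, done, results, 10), daemon=True)
        for _ in range(4)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    print(done.value, len(results))  # 40 40
```

Creating each `Value` with `lock=False` avoids taking two locks per update; correctness then depends on every writer honoring the shared `Lock`, which is what `search_process_mainloop` does. Each worker process also calls `setup_clients(1)` for itself, since a database socket is unsafe to share among workers.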