diff --git a/docs/references/benchmark.md b/docs/references/benchmark.md index a443b7be8f..15295e2e50 100644 --- a/docs/references/benchmark.md +++ b/docs/references/benchmark.md @@ -66,34 +66,35 @@ docker run -d --name infinity -v $HOME/infinity:/var/infinity --ulimit nofile=50 4. Run Benchmark: -Drop file cache before benchmark query latency. +Drop file cache before benchmark. ```bash echo 3 | sudo tee /proc/sys/vm/drop_caches ``` Tasks of the Python script `run.py` include: - - Delete the original data. - - Re-insert the data. - - Calculate the time to insert data and build index. - - Calculate query latency. - - Calculate QPS. + - Generate fulltext query set. + - Measure the time to import data and build index. + - Measure the query latency. + - Measure the QPS. ```bash $ python run.py -h -usage: run.py [-h] [--generate] [--import] [--query] [--query-express QUERY_EXPRESS] [--engine ENGINE] [--dataset DATASET] +usage: run.py [-h] [--generate] [--import] [--query QUERY] [--query-express QUERY_EXPRESS] [--concurrency CONCURRENCY] [--engine ENGINE] [--dataset DATASET] RAG Database Benchmark options: - -h, --help show this help message and exit - --generate Generate fulltext queries based on the dataset - --import Import data set into database engine - --query Run single client to benchmark query latency - --query-express QUERY_EXPRESS - Run multiple clients in express mode to benchmark QPS - --engine ENGINE database engine to benchmark, one of: all, infinity, qdrant, elasticsearch - --dataset DATASET data set to benchmark, one of: all, gist, sift, geonames, enwiki +-h, --help show this help message and exit +--generate Generate fulltext query set based on the dataset (default: False) +--import Import dataset into database engine (default: False) +--query QUERY Run the query set only once using given number of clients with recording the result and latency. This is for result validation and latency analysis (default: 0) +--query-express QUERY_EXPRESS +Run the query set randomly using given number of clients without recording the result and latency. This is for QPS measurement. (default: 0) +--concurrency CONCURRENCY +Choose concurrency mechanism, one of: mp - multiprocessing(recommended), mt - multithreading. (default: mp) +--engine ENGINE Choose database engine to benchmark, one of: infinity, qdrant, elasticsearch (default: infinity) +--dataset DATASET Choose dataset to benchmark, one of: gist, sift, geonames, enwiki (default: enwiki) ``` Following are commands for engine `infinity` and dataset `enwiki`: diff --git a/python/benchmark/clients/base_client.py b/python/benchmark/clients/base_client.py index 4fc87d36f9..1dfb7a859c 100644 --- a/python/benchmark/clients/base_client.py +++ b/python/benchmark/clients/base_client.py @@ -9,6 +9,7 @@ import h5py import numpy as np import threading +import multiprocessing class BaseClient: @@ -25,11 +26,19 @@ def __init__(self, conf_path: str) -> None: self.data = None self.queries = list() self.clients = list() - self.lock = threading.Lock() - self.next_begin = 0 - self.results = [] - self.done_queries = 0 - self.active_threads = 0 + # Following are for multithreading + self.mt_lock = threading.Lock() + self.mt_next_begin = 0 + self.mt_done_queries = 0 + self.mt_active_workers = 0 + self.mt_results = [] + # Following are for multiprocessing + self.mp_manager = multiprocessing.Manager() + self.mp_lock = multiprocessing.Lock() + self.mp_next_begin = multiprocessing.Value("i", 0, lock=False) + self.mp_done_queries = multiprocessing.Value("i", 0, lock=False) + self.mp_active_workers = multiprocessing.Value("i", 0, lock=False) + self.mp_results = self.mp_manager.list() @abstractmethod def upload(self): @@ -39,7 +48,7 @@ def upload(self): pass @abstractmethod - def setup_clients(self, num_threads=1): + def setup_clients(self, num_workers=1): pass @abstractmethod @@ -60,8 +69,8 @@ def download_data(self, url, target_path): else: subprocess.run(["wget", "-O", target_path, url], check=True) - def search(self, is_express=False, num_threads=1): - self.setup_clients(num_threads) + def search_mt(self, is_express=False, num_workers=1): + self.setup_clients(num_workers) query_path = os.path.join(self.path_prefix, self.data["query_path"]) _, ext = os.path.splitext(query_path) @@ -81,9 +90,9 @@ def search(self, is_express=False, num_threads=1): query = json.loads(line)["vector"] self.queries.append(query) - self.active_threads = num_threads + self.mt_active_workers = num_workers threads = [] - for i in range(num_threads): + for i in range(num_workers): threads.append( threading.Thread( target=self.search_thread_mainloop, @@ -91,7 +100,7 @@ def search(self, is_express=False, num_threads=1): daemon=True, ) ) - for i in range(num_threads): + for i in range(num_workers): threads[i].start() report_qps_sec = 60 @@ -107,7 +116,7 @@ def search(self, is_express=False, num_threads=1): done_queries_prev = 0 done_queries_curr = 0 - while self.active_threads > 0: + while self.mt_active_workers > 0: time.sleep(sleep_sec) sleep_cnt += 1 if sleep_cnt < report_qps_sec / sleep_sec: @@ -115,8 +124,8 @@ def search(self, is_express=False, num_threads=1): sleep_cnt = 0 now = time.time() if done_warm_up: - with self.lock: - done_queries_curr = self.done_queries + with self.mt_lock: + done_queries_curr = self.mt_done_queries avg_start = done_queries_curr / (now - start) avg_interval = (done_queries_curr - done_queries_prev) / ( now - report_prev @@ -124,11 +133,11 @@ def search(self, is_express=False, num_threads=1): done_queries_prev = done_queries_curr report_prev = now logging.info( - f"average QPS since {start_str}: {avg_start}, average QPS of last interval:{avg_interval}" + f"average QPS since {start_str}: {avg_start}, average QPS of last interval: {avg_interval}" ) else: - with self.lock: - self.done_queries = 0 + with self.mt_lock: + self.mt_done_queries = 0 start = now start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start)) report_prev = now @@ -137,10 +146,10 @@ def search(self, is_express=False, num_threads=1): "Collecting statistics for 30 minutes. Print statistics so far every minute. Type Ctrl+C to quit." ) - for i in range(num_threads): + for i in range(num_workers): threads[i].join() if not is_express: - self.save_and_check_results(self.results) + self.save_and_check_results(self.mt_results) def search_thread_mainloop(self, is_express: bool, client_id: int): query_batch = 100 @@ -152,37 +161,159 @@ def search_thread_mainloop(self, is_express: bool, client_id: int): for i in range(query_batch): query_id = local_rng.randrange(0, num_queries) _ = self.do_single_query(query_id, client_id) - with self.lock: - self.done_queries += query_batch + with self.mt_lock: + self.mt_done_queries += query_batch else: begin = 0 end = 0 local_results = list() while end < num_queries: - with self.lock: - self.done_queries += end - begin - begin = self.next_begin + with self.mt_lock: + self.mt_done_queries += end - begin + begin = self.mt_next_begin end = begin + query_batch if end > num_queries: end = num_queries - self.next_begin = end + self.mt_next_begin = end for query_id in range(begin, end): start = time.time() result = self.do_single_query(query_id, client_id) latency = (time.time() - start) * 1000 result = [(query_id, latency)] + result local_results.append(result) - with self.lock: - self.done_queries += end - begin - self.results += local_results - with self.lock: - self.active_threads -= 1 + with self.mt_lock: + self.mt_done_queries += end - begin + self.mt_results += local_results + with self.mt_lock: + self.mt_active_workers -= 1 + + def search_mp(self, is_express=False, num_workers=1): + query_path = os.path.join(self.path_prefix, self.data["query_path"]) + _, ext = os.path.splitext(query_path) + if self.data["mode"] == "fulltext": + assert ext == ".txt" + for line in open(query_path, "r"): + line = line.strip() + self.queries.append(line) + else: + self.data["mode"] == "vector" + if ext == ".hdf5": + with h5py.File(query_path, "r") as f: + self.queries = list(f["test"]) + else: + assert ext == "jsonl" + for line in open(query_path, "r"): + query = json.loads(line)["vector"] + self.queries.append(query) + + self.mp_active_workers.value = num_workers + workers = [] + for i in range(num_workers): + workers.append( + multiprocessing.Process( + target=self.search_process_mainloop, + args=[is_express], + daemon=True, + ) + ) + for i in range(num_workers): + workers[i].start() + + report_qps_sec = 60 + sleep_sec = 10 + sleep_cnt = 0 + done_warm_up = True + if is_express: + logging.info(f"Let database warm-up for {report_qps_sec} seconds") + done_warm_up = False + start = time.time() + start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start)) + report_prev = start + done_queries_prev = 0 + done_queries_curr = 0 + + while True: + active_workers = 0 + with self.mp_lock: + active_workers = self.mp_active_workers.value + if active_workers <= 0: + break + time.sleep(sleep_sec) + sleep_cnt += 1 + if sleep_cnt < report_qps_sec / sleep_sec: + continue + sleep_cnt = 0 + now = time.time() + if done_warm_up: + with self.mp_lock: + done_queries_curr = self.mp_done_queries.value + avg_start = done_queries_curr / (now - start) + avg_interval = (done_queries_curr - done_queries_prev) / ( + now - report_prev + ) + done_queries_prev = done_queries_curr + report_prev = now + logging.info( + f"average QPS since {start_str}: {avg_start}, average QPS of last interval: {avg_interval}" + ) + else: + with self.mp_lock: + self.mp_done_queries.value = 0 + start = now + start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start)) + report_prev = now + done_warm_up = True + logging.info( + "Collecting statistics for 30 minutes. Print statistics so far every minute. Type Ctrl+C to quit." + ) + + for i in range(num_workers): + workers[i].join() + if not is_express: + self.save_and_check_results(self.mp_results) + + def search_process_mainloop(self, is_express: bool): + self.setup_clients(1) # socket is unsafe to share among workers + query_batch = 100 + num_queries = len(self.queries) + if is_express: + local_rng = random.Random() # random number generator per thread + deadline = time.time() + 30 * 60 # 30 minutes + while time.time() < deadline: + for i in range(query_batch): + query_id = local_rng.randrange(0, num_queries) + _ = self.do_single_query(query_id, 0) + with self.mp_lock: + self.mp_done_queries.value += query_batch + else: + begin = 0 + end = 0 + local_results = list() + while end < num_queries: + with self.mp_lock: + self.mp_done_queries.value += end - begin + begin = self.mp_next_begin.value + end = begin + query_batch + if end > num_queries: + end = num_queries + self.mp_next_begin.value = end + for query_id in range(begin, end): + start = time.time() + result = self.do_single_query(query_id, 0) + latency = (time.time() - start) * 1000 + result = [(query_id, latency)] + result + local_results.append(result) + with self.mp_lock: + self.mp_done_queries.value += end - begin + self.mp_results += local_results + with self.mp_lock: + self.mp_active_workers.value -= 1 def save_and_check_results(self, results: list[list[Any]]): """ Compare the search results with ground truth to calculate recall. """ - self.results.sort(key=lambda x: x[0][0]) + results = sorted(results, key=lambda x: x[0][0]) if "result_path" in self.data: result_path = self.data["result_path"] with open(result_path, "w") as f: @@ -242,7 +373,8 @@ def run_experiment(self, args): self.upload() finish_time = time.time() logging.info(f"upload finish, cost time = {finish_time - start_time}") - elif args.query >= 1: - self.search(is_express=False, num_threads=args.query) - elif args.query_express >= 1: - self.search(is_express=True, num_threads=args.query_express) + elif args.query >= 1 or args.query_express >= 1: + is_express = True if args.query_express >= 1 else False + search_func = self.search_mp if args.concurrency == "mp" else self.search_mt + num_workers = max(args.query, args.query_express) + search_func(is_express, num_workers) diff --git a/python/benchmark/run.py b/python/benchmark/run.py index 3e2af70c84..cb8175949a 100644 --- a/python/benchmark/run.py +++ b/python/benchmark/run.py @@ -14,13 +14,14 @@ def parse_args() -> argparse.Namespace: parser: argparse.ArgumentParser = argparse.ArgumentParser( - description="RAG Database Benchmark" + description="RAG Database Benchmark", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "--generate", action="store_true", dest="generate_queries", - help="Generate fulltext queries based on the dataset", + help="Generate fulltext query set based on the dataset", ) parser.add_argument( "--import", @@ -40,21 +41,28 @@ def parse_args() -> argparse.Namespace: type=int, default=0, dest="query_express", - help="Run the query set randomly using given number of clients without recording the result. This is for QPS measurement.", + help="Run the query set randomly using given number of clients without recording the result and latency. This is for QPS measurement.", + ) + parser.add_argument( + "--concurrency", + type=str, + default="mp", + dest="concurrency", + help="Choose concurrency mechanism, one of: mp - multiprocessing(recommended), mt - multithreading.", ) parser.add_argument( "--engine", type=str, default="infinity", dest="engine", - help="database engine to benchmark, one of: " + ", ".join(ENGINES), + help="Choose database engine to benchmark, one of: " + ", ".join(ENGINES), ) parser.add_argument( "--dataset", type=str, default="enwiki", dest="dataset", - help="dataset to benchmark, one of: " + ", ".join(DATA_SETS), + help="Choose dataset to benchmark, one of: " + ", ".join(DATA_SETS), ) return parser.parse_args()