From b28c81841d41e4674033a49cec55a87273071936 Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 15 Jun 2020 16:55:14 +0800 Subject: [PATCH 01/10] use async interface --- python/paddle_edl/distill/distill_reader.py | 37 ++- python/paddle_edl/distill/distill_worker.py | 240 ++++++++++++++++++++ 2 files changed, 276 insertions(+), 1 deletion(-) diff --git a/python/paddle_edl/distill/distill_reader.py b/python/paddle_edl/distill/distill_reader.py index 55ef0c33..541ded9d 100644 --- a/python/paddle_edl/distill/distill_reader.py +++ b/python/paddle_edl/distill/distill_reader.py @@ -177,6 +177,40 @@ def _start_predict_worker(self): process.append(worker) return process + def _start_predict_process(self): + if not self._is_predict_start: + stop_event = mps.Event() + predict_process = mps.Process( + target=distill_worker.predict_process, + args=(self._predict_server_queue, + self._predict_server_result_queue, + self._reader_out_queue, self._predict_out_queue, + self._feeds, self._fetchs, self._serving_conf_file, + stop_event, self._predict_cond)) + predict_process.daemon = True + predict_process.start() + + self._predict_manage_stop_event = threading.Event() + self._predict_manage_thread = threading.Thread( + target=distill_worker.predict_manage_worker, + args=( + [predict_process], + self._predict_server_queue, + self._predict_server_result_queue, + self._require_num, + self._predict_stop_events, + self._get_servers, + self._predict_manage_stop_event, + self._predict_cond, )) + self._predict_manage_thread.daemon = True + self._predict_manage_thread.start() + + self._is_predict_start = True + else: + # wake up predict process + with self._predict_cond: + self._predict_cond.notify() + def _start_predict_worker_pool(self): if not self._is_predict_start: # start predict worker pool @@ -365,7 +399,8 @@ def __call__(self): # >>> there will only be thread 2 and the lock will be held forever. # So need to move start_predict_worker_pool to the end if we use logging in predict # manager thread, or for the sake of safety, don't use logging? - self._start_predict_worker_pool() + # self._start_predict_worker_pool() + self._start_predict_process() for data in distill_worker.fetch_out( self._reader_type, self._predict_out_queue, diff --git a/python/paddle_edl/distill/distill_worker.py b/python/paddle_edl/distill/distill_worker.py index 9c9691d5..41cfb723 100644 --- a/python/paddle_edl/distill/distill_worker.py +++ b/python/paddle_edl/distill/distill_worker.py @@ -19,6 +19,7 @@ import six import sys import time +import threading from paddle_serving_client import Client from six.moves import queue @@ -46,6 +47,8 @@ class ServerItem(object): PENDING = 'pending' ERROR = 'error' FINISHED = 'finished' + ADD = 'add' + STOP = 'stop' def __init__(self, server_id, server, stop_event_id, state=PENDING): self.server_id = server_id @@ -315,6 +318,243 @@ def __del__(self): pass +class AsyncPredictClient(object): + def __init__(self, server, config_file, feeds, fetchs, max_failed_times=3): + self.server = server + self._config_file = config_file + self._predict_feed_idxs = [] + self._predict_feed_shapes = dict() + self._predict_feed_size = dict() + self._feeds = feeds + self._fetchs = fetchs + self._max_failed_times = max_failed_times + self.client = None + self._has_predict = False + self.need_stop = False + self.future_count = 0 + logger.info((server, config_file, feeds, fetchs, max_failed_times)) + + def __lt__(self, other): + return self.future_count < other.future_count + + def connect(self): + """ connect success, return True, else return False""" + try: + client = MultiLangClient() + client.load_client_config(self._config_file) + client.connect([self.server]) + self.client = client + except Exception as e: + logger.error('Exception when connect server={}, Exception is:'. + format(str(self.server))) + logger.error(str(e)) + return False + + self._predict_feed_idxs = [] + self._predict_feed_shapes = dict() + self._predict_feed_size = dict() + for feed_idx, feed_name in enumerate(self._feeds): + if feed_name in self.client.get_feed_names(): + self._predict_feed_idxs.append(feed_idx) + self._predict_feed_shapes[feed_name] = tuple( + self.client.feed_shapes_[feed_name]) + self._predict_feed_size[feed_name] = reduce( + lambda x, y: x * y, self._predict_feed_shapes[feed_name]) + return True + + def _preprocess(self, feed_data): + """ feed_data(list). format e.g. [(img, label, img1, label1), (img, label, img1, label1)] + However, predict may only need (img, img1). + return [{'img': img, 'img1': img1}, {'img': img, 'img1': img1}] + """ + feed_map_list = [] + for batch_idx in range(len(feed_data)): + feed_map = dict() + for feed_idx in self._predict_feed_idxs: + feed_name = self._feeds[feed_idx] + feed_size = self._predict_feed_size[feed_name] + feed_shape = self._predict_feed_shapes[feed_name] + + data = feed_data[batch_idx][feed_idx] + if data.size == feed_size: + data = data.reshape(feed_shape) + + feed_map[feed_name] = data + feed_map_list.append(feed_map) + + logger.debug('predict feed_map_list len={}'.format(len(feed_map_list))) + return feed_map_list + + def _postprocess(self, fetch_map_list, batch_size): + """ fetch_map_list(map): format e.g. {'predict0': np[bsize, ..], 'predict1': np[bsize, ..]} + return [(predict0, predict1), (predict0, predict1)] + """ + predict_data = [tuple() for _ in range(batch_size)] + for fetch_name in self._fetchs: + batch_fetch_data = fetch_map_list[fetch_name] + for batch_idx, fetch_data in enumerate(batch_fetch_data): + predict_data[batch_idx] += (fetch_data, ) + return predict_data + + def predict(self, feed_data): + """ predict success, return (True, predict_data), + else return (False, None)""" + self._has_predict = True + feed_map_list = self._preprocess(feed_data) + future = self.client.predict( + feed=feed_map_list, fetch=self._fetchs, asyn=True) + self.future_count += 1 + return future + + def result(self, future, feed_data): + fetch_map_list = None + try: + fetch_map_list = future.result() + except Exception as e: + logger.warning('Failed with server={}'.format(self.server)) + logger.warning('Exception:\n{}'.format(str(e))) + + self.future_count -= 1 + + if fetch_map_list is None: + return False, None + + predict_data = self._postprocess(fetch_map_list, len(feed_data)) + return True, predict_data + + +class _TestNopAsyncPredictClient(AsyncPredictClient): + class _Future(object): + def __init__(self, feed_data): + self._feed_data = feed_data + + def add_done_callback(self, call_back): + call_back(self) + + def connect(self): + return True + + def predict(self, feed_data): + return self._Future(feed_data) + + def result(self, future, feed_data): + predict_data = [tuple() for _ in range(len(feed_data))] + return True, predict_data + + +class PredictPool(object): + def __init__(self, server_result_queue, max_clients=1): + self._clients = queue.PriorityQueue() + self._server_to_clients = dict() + + self._server_result_queue = server_result_queue + + def add_client(self, server_item, feeds, fetchs, conf_file): + server = server_item.server + if server_item.server in self._server_to_clients: + return True + + predict_server = AsyncPredictClient if _NOP_PREDICT_TEST is False else _TestNopAsyncPredictClient + client = predict_server(server, conf_file, feeds, fetchs) + if not client.connect(): + return False + + self._server_to_clients[server] = (server_item, client) + self._clients.put(client) + return True + + def stop_client(self, server_item): + server = server_item.server + client = self._server_to_clients[server] + client.need_stop = True + + def rm_client(self, client): + server = client.server + server_item, client = self._server_to_clients[server] + del self._server_to_clients[server] + + server_item.state = ServerItem.FINISHED + self._server_result_queue.put(server_item) + + def run(self, in_queue, out_queue): + finished_task_count = [0] + task_count_lock = threading.Lock() + + while True: + data = in_queue.get() + if isinstance(data, _PoisonPill): + poison_pill = data + if finished_task_count[0] == poison_pill.feed_count: + poison_pill.predict_count = poison_pill.feed_count + out_queue.put(poison_pill) + break # all task finished + + in_queue.put(poison_pill) # write back poison pill + continue # continue process failed task + + task, read_data = data + while True: + client = self._clients.get() + if client.need_stop: + self.rm_client(client) + continue + + def predict_call_back(call_future): + success, predict_data = client.result(call_future, + read_data) + if not success: + in_queue.put(data) + client.need_stop = True # FIXME. stop? + return + + out_data = read_data + for i in range(len(out_data)): + out_data[i] += predict_data[i] + out_queue.put((task, out_data)) + with task_count_lock: + finished_task_count[0] += 1 + + future = client.predict(read_data) + future.add_done_callback(predict_call_back) + + self._clients.put(client) + break + + +def predict_process(server_queue, server_result_queue, in_queue, out_queue, + feeds, fetchs, conf_file, stop_event, predict_cond): + client_pool = PredictPool(server_result_queue) + manager_need_stop = False + + def server_manager(): + while not manager_need_stop: + try: + server_item = server_queue.get(timeout=1) + except queue.Empty: + continue + + if server_item is None: + server_result_queue.put(None) + return + + if server_item.state == ServerItem.PENDING: + client_pool.add_client(server_item, feeds, fetchs, conf_file) + elif server_item.state == ServerItem.STOP: + client_pool.stop_client(server_item) + + manage_thread = threading.Thread(target=server_manager, ) + manage_thread.daemon = True + manage_thread.start() + + while not stop_event.is_set(): + client_pool.run(in_queue, out_queue) + with predict_cond: + predict_cond.wait() + + manager_need_stop = True + manage_thread.join() + + def predict_worker(server_queue, server_result_queue, working_predict_count, in_queue, out_queue, feeds, fetchs, conf_file, stop_events, predict_lock, global_finished_task, predict_cond): From f79c464e9fa2d2f861d0b683b4cdf3b3b16b79c6 Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 15 Jun 2020 20:45:04 +0800 Subject: [PATCH 02/10] update distill worker --- python/paddle_edl/distill/distill_worker.py | 24 ++++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/python/paddle_edl/distill/distill_worker.py b/python/paddle_edl/distill/distill_worker.py index 41cfb723..90e32d69 100644 --- a/python/paddle_edl/distill/distill_worker.py +++ b/python/paddle_edl/distill/distill_worker.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import logging import numpy as np import os @@ -21,7 +22,7 @@ import time import threading -from paddle_serving_client import Client +from paddle_serving_client import Client, MultiLangClient from six.moves import queue from six.moves import reduce from .timeline import _TimeLine @@ -499,23 +500,30 @@ def run(self, in_queue, out_queue): self.rm_client(client) continue - def predict_call_back(call_future): - success, predict_data = client.result(call_future, - read_data) + def predict_call_back( + call_future, + _client, + _data, ): + _task, _read_data = _data + success, predict_data = _client.result(call_future, + _read_data) if not success: in_queue.put(data) - client.need_stop = True # FIXME. stop? + _client.need_stop = True # FIXME. stop? return - out_data = read_data + out_data = _read_data for i in range(len(out_data)): out_data[i] += predict_data[i] - out_queue.put((task, out_data)) + out_queue.put((_task, out_data)) with task_count_lock: finished_task_count[0] += 1 + #logger.info('client={}'.format(client)) future = client.predict(read_data) - future.add_done_callback(predict_call_back) + future.add_done_callback( + functools.partial( + predict_call_back, _client=client, _data=data)) self._clients.put(client) break From dc1f8fee51e67b4e0904b326fb8bff00de5f34e2 Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 16 Jun 2020 14:27:29 +0800 Subject: [PATCH 03/10] fix hang --- python/paddle_edl/distill/distill_worker.py | 24 +++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/python/paddle_edl/distill/distill_worker.py b/python/paddle_edl/distill/distill_worker.py index 90e32d69..8c0a68a9 100644 --- a/python/paddle_edl/distill/distill_worker.py +++ b/python/paddle_edl/distill/distill_worker.py @@ -423,6 +423,18 @@ def result(self, future, feed_data): predict_data = self._postprocess(fetch_map_list, len(feed_data)) return True, predict_data + def __del__(self): + try: + # avoid serving exit bug when hasn't predict + if self.client is not None and self._has_predict: + self.client.release() + except Exception as e: + logger.critical('Release client failed with server={}, ' + 'there may be an unknown error'.format( + self.server)) + logger.critical('Exception:\n{}'.format(str(e))) + logger.warning('Stopped predict server={}'.format(self.server)) + class _TestNopAsyncPredictClient(AsyncPredictClient): class _Future(object): @@ -445,7 +457,8 @@ def result(self, future, feed_data): class PredictPool(object): def __init__(self, server_result_queue, max_clients=1): - self._clients = queue.PriorityQueue() + #self._clients = queue.PriorityQueue() + self._clients = queue.Queue() self._server_to_clients = dict() self._server_result_queue = server_result_queue @@ -495,6 +508,7 @@ def run(self, in_queue, out_queue): task, read_data = data while True: + #logger.info('task_id={}'.format(task.task_id)) client = self._clients.get() if client.need_stop: self.rm_client(client) @@ -508,7 +522,7 @@ def predict_call_back( success, predict_data = _client.result(call_future, _read_data) if not success: - in_queue.put(data) + in_queue.put(_data) _client.need_stop = True # FIXME. stop? return @@ -519,7 +533,6 @@ def predict_call_back( with task_count_lock: finished_task_count[0] += 1 - #logger.info('client={}'.format(client)) future = client.predict(read_data) future.add_done_callback( functools.partial( @@ -546,7 +559,10 @@ def server_manager(): return if server_item.state == ServerItem.PENDING: - client_pool.add_client(server_item, feeds, fetchs, conf_file) + if not client_pool.add_client(server_item, feeds, fetchs, + conf_file): + server_item.state = ServerItem.FINISHED + server_result_queue.put(server_item) elif server_item.state == ServerItem.STOP: client_pool.stop_client(server_item) From 2b5fefcf0d5fc32a61a660ef3809fa00187ba36e Mon Sep 17 00:00:00 2001 From: WangXi Date: Thu, 18 Jun 2020 11:26:18 +0800 Subject: [PATCH 04/10] limit serving client concurrent --- python/paddle_edl/distill/distill_reader.py | 4 +- python/paddle_edl/distill/distill_worker.py | 62 +++++++++++++++------ 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/python/paddle_edl/distill/distill_reader.py b/python/paddle_edl/distill/distill_reader.py index 541ded9d..4b678c3a 100644 --- a/python/paddle_edl/distill/distill_reader.py +++ b/python/paddle_edl/distill/distill_reader.py @@ -245,7 +245,7 @@ def _init_args(self): self._reader_out_queue = mps.Queue() self._reader_stop_event = mps.Event() self._reader_cond = mps.Condition() - self._task_semaphore = mps.Semaphore(2 * self._require_num + 2) + self._task_semaphore = mps.Semaphore(3 * self._require_num + 1) # predict self._predict_server_queue = mps.Queue(self._require_num) @@ -377,7 +377,7 @@ def print_config(self): 'teacher_service_name': self._service_name, 'reader_type': self._reader_type, } - for config, value in print_config.iteritems(): + for config, value in print_config.items(): print("%s: %s" % (config, value)) print("------------------------------------------------") diff --git a/python/paddle_edl/distill/distill_worker.py b/python/paddle_edl/distill/distill_worker.py index 8c0a68a9..5e545378 100644 --- a/python/paddle_edl/distill/distill_worker.py +++ b/python/paddle_edl/distill/distill_worker.py @@ -332,6 +332,8 @@ def __init__(self, server, config_file, feeds, fetchs, max_failed_times=3): self.client = None self._has_predict = False self.need_stop = False + + self.lock = None self.future_count = 0 logger.info((server, config_file, feeds, fetchs, max_failed_times)) @@ -361,6 +363,7 @@ def connect(self): self.client.feed_shapes_[feed_name]) self._predict_feed_size[feed_name] = reduce( lambda x, y: x * y, self._predict_feed_shapes[feed_name]) + self.lock = threading.Lock() return True def _preprocess(self, feed_data): @@ -404,7 +407,8 @@ def predict(self, feed_data): feed_map_list = self._preprocess(feed_data) future = self.client.predict( feed=feed_map_list, fetch=self._fetchs, asyn=True) - self.future_count += 1 + with self.lock: + self.future_count += 1 return future def result(self, future, feed_data): @@ -415,7 +419,8 @@ def result(self, future, feed_data): logger.warning('Failed with server={}'.format(self.server)) logger.warning('Exception:\n{}'.format(str(e))) - self.future_count -= 1 + with self.lock: + self.future_count -= 1 if fetch_map_list is None: return False, None @@ -456,11 +461,14 @@ def result(self, future, feed_data): class PredictPool(object): - def __init__(self, server_result_queue, max_clients=1): - #self._clients = queue.PriorityQueue() - self._clients = queue.Queue() + def __init__(self, server_result_queue, max_concurrent=3): + self._clients = queue.PriorityQueue() self._server_to_clients = dict() + self._client_num_lock = threading.Lock() + self._client_num = 0 + self._max_concurrent = max_concurrent + self._server_result_queue = server_result_queue def add_client(self, server_item, feeds, fetchs, conf_file): @@ -475,6 +483,8 @@ def add_client(self, server_item, feeds, fetchs, conf_file): self._server_to_clients[server] = (server_item, client) self._clients.put(client) + with self._client_num_lock: + self._client_num += 1 return True def stop_client(self, server_item): @@ -486,13 +496,18 @@ def rm_client(self, client): server = client.server server_item, client = self._server_to_clients[server] del self._server_to_clients[server] + with self._client_num_lock: + self._client_num -= 1 server_item.state = ServerItem.FINISHED self._server_result_queue.put(server_item) def run(self, in_queue, out_queue): + finished_task_count_lock = threading.Lock() finished_task_count = [0] - task_count_lock = threading.Lock() + running_task_count_lock = threading.Lock() + running_task_count = [0] + task_cond = threading.Condition() while True: data = in_queue.get() @@ -508,35 +523,46 @@ def run(self, in_queue, out_queue): task, read_data = data while True: - #logger.info('task_id={}'.format(task.task_id)) client = self._clients.get() if client.need_stop: self.rm_client(client) continue - def predict_call_back( - call_future, - _client, - _data, ): - _task, _read_data = _data - success, predict_data = _client.result(call_future, - _read_data) + with running_task_count_lock: + running_task_count[0] += 1 + # limit max concurrent of client + while running_task_count[0] > self._client_num * \ + self._max_concurrent: + with task_cond: + task_cond.wait() + + def predict_call_back(call_future, _predict_client, _in_data): + with running_task_count_lock: + running_task_count[0] -= 1 + with task_cond: + task_cond.notify() + + _task, _read_data = _in_data + success, predict_data = _predict_client.result(call_future, + _read_data) if not success: - in_queue.put(_data) - _client.need_stop = True # FIXME. stop? + in_queue.put(_in_data) + _predict_client.need_stop = True return out_data = _read_data for i in range(len(out_data)): out_data[i] += predict_data[i] out_queue.put((_task, out_data)) - with task_count_lock: + with finished_task_count_lock: finished_task_count[0] += 1 future = client.predict(read_data) future.add_done_callback( functools.partial( - predict_call_back, _client=client, _data=data)) + predict_call_back, + _predict_client=client, + _in_data=data)) self._clients.put(client) break From 5652bd518ff656bb779a3d0336eb74607c2a7f95 Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 19 Jun 2020 14:18:53 +0800 Subject: [PATCH 05/10] clear code --- python/paddle_edl/distill/distill_reader.py | 48 --- python/paddle_edl/distill/distill_worker.py | 338 ++------------------ 2 files changed, 33 insertions(+), 353 deletions(-) diff --git a/python/paddle_edl/distill/distill_reader.py b/python/paddle_edl/distill/distill_reader.py index 4b678c3a..3df1b2c5 100644 --- a/python/paddle_edl/distill/distill_reader.py +++ b/python/paddle_edl/distill/distill_reader.py @@ -154,29 +154,6 @@ def _start_reader_worker(self): with self._reader_cond: self._reader_cond.notify() - def _start_predict_worker(self): - process = [] - for i in range(self._require_num): - worker = mps.Process( - target=distill_worker.predict_worker, - args=( - self._predict_server_queue, - self._predict_server_result_queue, - self._working_predict_count, - self._reader_out_queue, - self._predict_out_queue, - self._feeds, - self._fetchs, - self._serving_conf_file, - self._predict_stop_events, - self._predict_lock, - self._predict_finished_task, - self._predict_cond, )) - worker.daemon = True - worker.start() - process.append(worker) - return process - def _start_predict_process(self): if not self._is_predict_start: stop_event = mps.Event() @@ -211,31 +188,6 @@ def _start_predict_process(self): with self._predict_cond: self._predict_cond.notify() - def _start_predict_worker_pool(self): - if not self._is_predict_start: - # start predict worker pool - process = self._start_predict_worker() - self._predict_manage_stop_event = threading.Event() - self._predict_manage_thread = threading.Thread( - target=distill_worker.predict_manage_worker, - args=( - process, - self._predict_server_queue, - self._predict_server_result_queue, - self._require_num, - self._predict_stop_events, - self._get_servers, - self._predict_manage_stop_event, - self._predict_cond, )) - self._predict_manage_thread.daemon = True - self._predict_manage_thread.start() - - self._is_predict_start = True - else: - # wake up predict worker pool - with self._predict_cond: - self._predict_cond.notify_all() - def _init_args(self): if not self._is_args_init: self._init_conf_file_from_env() diff --git a/python/paddle_edl/distill/distill_worker.py b/python/paddle_edl/distill/distill_worker.py index 5e545378..8e76177b 100644 --- a/python/paddle_edl/distill/distill_worker.py +++ b/python/paddle_edl/distill/distill_worker.py @@ -22,7 +22,8 @@ import time import threading -from paddle_serving_client import Client, MultiLangClient +# from paddle_serving_client import Client, MultiLangClient +from paddle_serving_client import Client from six.moves import queue from six.moves import reduce from .timeline import _TimeLine @@ -188,137 +189,6 @@ def predict(self, feed_data): raise NotImplementedError() -class PaddlePredictServer(PredictServer): - def __init__(self, server, config_file, feeds, fetchs, max_failed_times=3): - self._server = server - self._config_file = config_file - self._predict_feed_idxs = [] - self._predict_feed_shapes = dict() - self._predict_feed_size = dict() - self._feeds = feeds - self._fetchs = fetchs - self._max_failed_times = max_failed_times - self.client = None - self._has_predict = False - logger.info((server, config_file, feeds, fetchs, max_failed_times)) - - self._time_line = _TimeLine() - - def connect(self): - """ connect success, return True, else return False""" - try: - client = Client() - client.load_client_config(self._config_file) - client.connect([self._server]) - self.client = client - except Exception as e: - logger.error('Exception when connect server={}, Exception is:'. - format(str(self._server))) - logger.error(str(e)) - return False - - self._predict_feed_idxs = [] - self._predict_feed_shapes = dict() - self._predict_feed_size = dict() - for feed_idx, feed_name in enumerate(self._feeds): - if feed_name in self.client.get_feed_names(): - self._predict_feed_idxs.append(feed_idx) - self._predict_feed_shapes[feed_name] = tuple( - self.client.feed_shapes_[feed_name]) - self._predict_feed_size[feed_name] = reduce( - lambda x, y: x * y, self._predict_feed_shapes[feed_name]) - return True - - def _preprocess(self, feed_data): - """ feed_data(list). format e.g. [(img, label, img1, label1), (img, label, img1, label1)] - However, predict may only need (img, img1). - return [{'img': img, 'img1': img1}, {'img': img, 'img1': img1}] - """ - feed_map_list = [] - for batch_idx in range(len(feed_data)): - feed_map = dict() - for feed_idx in self._predict_feed_idxs: - feed_name = self._feeds[feed_idx] - feed_size = self._predict_feed_size[feed_name] - feed_shape = self._predict_feed_shapes[feed_name] - - data = feed_data[batch_idx][feed_idx] - if data.size == feed_size: - data = data.reshape(feed_shape) - - feed_map[feed_name] = data - feed_map_list.append(feed_map) - - logger.debug('predict feed_map_list len={}'.format(len(feed_map_list))) - return feed_map_list - - def _postprocess(self, fetch_map_list, batch_size): - """ fetch_map_list(map): format e.g. {'predict0': np[bsize, ..], 'predict1': np[bsize, ..]} - return [(predict0, predict1), (predict0, predict1)] - """ - predict_data = [tuple() for _ in range(batch_size)] - for fetch_name in self._fetchs: - batch_fetch_data = fetch_map_list[fetch_name] - for batch_idx, fetch_data in enumerate(batch_fetch_data): - predict_data[batch_idx] += (fetch_data, ) - return predict_data - - def predict(self, feed_data): - """ predict success, return (True, predict_data), - else return (False, None)""" - self._has_predict = True - self._time_line.reset() - feed_map_list = self._preprocess(feed_data) - self._time_line.record('predict_preprocess') - - fetch_map_list = None - for i in range(self._max_failed_times): - try: - fetch_map_list = self.client.predict( - feed=feed_map_list, fetch=self._fetchs) - if fetch_map_list is None: - raise Exception('fetch_map_list should not be None') - break - except Exception as e: - logger.warning('Failed {} times with server={}'.format( - i + 1, self._server)) - logger.warning('Exception:\n{}'.format(str(e))) - # time.sleep(0.1 * (i + 1)) - - self._time_line.record('real_predict') - - if fetch_map_list is None: - return False, None - - predict_data = self._postprocess(fetch_map_list, len(feed_data)) - self._time_line.record('postprocess') - return True, predict_data - - def __del__(self): - try: - # avoid serving exit bug when hasn't predict - if self.client is not None and self._has_predict: - self.client.release() - except Exception as e: - logger.critical('Release client failed with server={}, ' - 'there may be an unknown error'.format( - self._server)) - logger.critical('Exception:\n{}'.format(str(e))) - logger.warning('Stopped predict server={}'.format(self._server)) - - -class _TestNopPaddlePredictServer(PaddlePredictServer): - def connect(self): - return True - - def predict(self, feed_data): - predict_data = [tuple() for _ in range(len(feed_data))] - return True, predict_data - - def __del__(self): - pass - - class AsyncPredictClient(object): def __init__(self, server, config_file, feeds, fetchs, max_failed_times=3): self.server = server @@ -570,44 +440,6 @@ def predict_call_back(call_future, _predict_client, _in_data): def predict_process(server_queue, server_result_queue, in_queue, out_queue, feeds, fetchs, conf_file, stop_event, predict_cond): - client_pool = PredictPool(server_result_queue) - manager_need_stop = False - - def server_manager(): - while not manager_need_stop: - try: - server_item = server_queue.get(timeout=1) - except queue.Empty: - continue - - if server_item is None: - server_result_queue.put(None) - return - - if server_item.state == ServerItem.PENDING: - if not client_pool.add_client(server_item, feeds, fetchs, - conf_file): - server_item.state = ServerItem.FINISHED - server_result_queue.put(server_item) - elif server_item.state == ServerItem.STOP: - client_pool.stop_client(server_item) - - manage_thread = threading.Thread(target=server_manager, ) - manage_thread.daemon = True - manage_thread.start() - - while not stop_event.is_set(): - client_pool.run(in_queue, out_queue) - with predict_cond: - predict_cond.wait() - - manager_need_stop = True - manage_thread.join() - - -def predict_worker(server_queue, server_result_queue, working_predict_count, - in_queue, out_queue, feeds, fetchs, conf_file, stop_events, - predict_lock, global_finished_task, predict_cond): signal_exit = [False, ] # Define signal handler function @@ -624,148 +456,44 @@ def predict_signal_handle(signum, frame): signal.signal(signal.SIGTERM, predict_signal_handle) try: - while True: - # get server - server_item = server_queue.get() - if server_item is None: - server_result_queue.put(None) - return - - # predict - success = predict_loop(server_item, working_predict_count, - in_queue, out_queue, feeds, fetchs, - conf_file, stop_events, predict_lock, - global_finished_task, predict_cond) - - server_item.state = ServerItem.FINISHED if success else ServerItem.ERROR - server_result_queue.put(server_item) - logger.info('Stopped server={}'.format(server_item.server)) - except Exception as e: - if signal_exit[0] is True: - pass - else: - six.reraise(*sys.exc_info()) + client_pool = PredictPool(server_result_queue) + manager_need_stop = False + + def server_manager(): + while not manager_need_stop: + try: + server_item = server_queue.get(timeout=1) + except queue.Empty: + continue + if server_item is None: + server_result_queue.put(None) + return -def predict_loop(server_item, working_predict_count, in_queue, out_queue, - feeds, fetchs, conf_file, stop_events, predict_lock, - global_finished_task, predict_cond): - logger.info('connect server={}'.format(server_item.server)) - predict_server = PaddlePredictServer if _NOP_PREDICT_TEST is False else _TestNopPaddlePredictServer - client = predict_server(server_item.server, conf_file, feeds, fetchs) - if not client.connect(): - return False + if server_item.state == ServerItem.PENDING: + if not client_pool.add_client(server_item, feeds, fetchs, + conf_file): + server_item.state = ServerItem.FINISHED + server_result_queue.put(server_item) + elif server_item.state == ServerItem.STOP: + client_pool.stop_client(server_item) - stop_event = stop_events[server_item.stop_event_id] - with predict_lock: - working_predict_count.value += 1 - - time_line = _TimeLine() - finished_task = 0 - # predict loop - while not stop_event.is_set(): - data = in_queue.get() - time_line.record('get_data') - - # Poison - if isinstance(data, _PoisonPill): - poison_pill = data - all_worker_done = False - - with predict_lock: - # accumulate success predict task count - poison_pill.predict_count += finished_task - poison_pill.predict_count += global_finished_task.value - - # clean local and global finished task - finished_task = 0 - global_finished_task.value = 0 - - # last process - if working_predict_count.value == 1: - if poison_pill.predict_count == poison_pill.feed_count: - working_predict_count.value -= 1 - logger.debug('pid={} write poison to complete queue'. - format(os.getpid())) - all_worker_done = True - else: - # NOTE. some predict worker failed, - # there are still tasks that have not been processed. - assert poison_pill.predict_count < poison_pill.feed_count, \ - "if failed, predict_count={} must < feed_count={}".\ - format(poison_pill.predict_count, poison_pill.feed_count) - - in_queue.put(poison_pill) # write back poison pill - continue # continue process failed task - else: # not last process - logger.debug('pid={} write poison back to ready'.format( - os.getpid())) - assert poison_pill.predict_count <= poison_pill.feed_count, \ - "predict_count={} must <= feed_count={}".format(poison_pill.predict_count, - poison_pill.feed_count) - working_predict_count.value -= 1 + manage_thread = threading.Thread(target=server_manager, ) + manage_thread.daemon = True + manage_thread.start() + while not stop_event.is_set(): + client_pool.run(in_queue, out_queue) with predict_cond: - if all_worker_done is True: - out_queue.put(poison_pill) # poison consumer - else: - in_queue.put(poison_pill) # poison other predict worker - if stop_event.is_set(): - break - # wait next reader iter or last failed predict job predict_cond.wait() - with predict_lock: - # go on working - working_predict_count.value += 1 - continue - - success, out_data = client_predict(client, data) - time_line.record('predict') - - if not success: - with predict_lock: - global_finished_task.value += finished_task - in_queue.put(data) # write back failed task data - # last process - if working_predict_count.value == 1: - # NOTE. need notify other predict worker, or maybe deadlock - with predict_cond: - predict_cond.notify_all() - working_predict_count.value -= 1 - return False - - out_queue.put(out_data) - finished_task += 1 - time_line.record('put_data') - - # disconnect with server - with predict_lock: - global_finished_task.value += finished_task - # last process - if working_predict_count.value == 1: - # NOTE. need notify other predict worker, or maybe deadlock. - with predict_cond: - predict_cond.notify_all() - working_predict_count.value -= 1 - return True - - -def client_predict(client, data): - # read_data format e.g. [(img, label, img1, label1), (img, label, img1, label1)] - # predict_data format e.g. [(predict0, predict1), (predict0, predict1)] - # out_data = read_data + predict_data, will be - # [(img, label, img1, label1, predict0, predict1), - # (img, label, img1, label1, predict0, predict1)] - task, read_data = data - success, predict_data = client.predict(read_data) - if not success: - return False, None - - out_data = read_data - for i in range(len(out_data)): - out_data[i] += predict_data[i] - return True, (task, out_data) + manager_need_stop = True + manage_thread.join() + except Exception as e: + if signal_exit[0] is True: + pass + else: + six.reraise(*sys.exc_info()) class ReaderType(object): From b3b2df432b7eea3eb060a6ddadf61c196f367d7e Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 19 Jun 2020 16:04:45 +0800 Subject: [PATCH 06/10] manager remove server --- python/paddle_edl/distill/distill_reader.py | 2 +- python/paddle_edl/distill/distill_worker.py | 39 +++++++++------------ 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/python/paddle_edl/distill/distill_reader.py b/python/paddle_edl/distill/distill_reader.py index 3df1b2c5..c780d971 100644 --- a/python/paddle_edl/distill/distill_reader.py +++ b/python/paddle_edl/distill/distill_reader.py @@ -197,7 +197,7 @@ def _init_args(self): self._reader_out_queue = mps.Queue() self._reader_stop_event = mps.Event() self._reader_cond = mps.Condition() - self._task_semaphore = mps.Semaphore(3 * self._require_num + 1) + self._task_semaphore = mps.Semaphore(4 * self._require_num) # predict self._predict_server_queue = mps.Queue(self._require_num) diff --git a/python/paddle_edl/distill/distill_worker.py b/python/paddle_edl/distill/distill_worker.py index 8e76177b..e294ac05 100644 --- a/python/paddle_edl/distill/distill_worker.py +++ b/python/paddle_edl/distill/distill_worker.py @@ -46,16 +46,14 @@ def _is_server_alive(server): class ServerItem(object): - PENDING = 'pending' + RUNNING = 'pending' ERROR = 'error' FINISHED = 'finished' - ADD = 'add' - STOP = 'stop' + STOPPING = 'stopping' - def __init__(self, server_id, server, stop_event_id, state=PENDING): + def __init__(self, server_id, server, state=RUNNING): self.server_id = server_id self.server = server - self.stop_event_id = stop_event_id self.state = state @@ -72,9 +70,6 @@ def shutdown_one_process(): server_id = 0 # not yet used server_to_item = dict() # server to server_item idle_predict_num = require_num - event_set = set() - for i in range(require_num): - event_set.add(i) # Fix the order of object destruction first_in = True @@ -94,10 +89,11 @@ def shutdown_one_process(): while len(rm_servers) != 0: server = rm_servers.pop() server_item = server_to_item[server] - stop_event_id = server_item.stop_event_id - # set stop event - if not predict_stop_events[stop_event_id].is_set(): - predict_stop_events[stop_event_id].set() + + # need stop + if server_item.state != ServerItem.STOPPING: + server_item.state = ServerItem.STOPPING + server_queue.put(server_item) logger.info('Removing server={}'.format(server)) # Add servers @@ -112,8 +108,7 @@ def shutdown_one_process(): continue idle_predict_num -= 1 - event_id = event_set.pop() - server_item = ServerItem(server_id, server, event_id) + server_item = ServerItem(server_id, server) server_queue.put(server_item) server_to_item[server] = server_item server_id += 1 @@ -122,13 +117,8 @@ def shutdown_one_process(): try: # server job stop, return back stop_event_id server_result_item = server_result_queue.get(timeout=2) - stop_event_id = server_result_item.stop_event_id - event_set.add(stop_event_id) del server_to_item[server_result_item.server] - # clear event - predict_stop_events[stop_event_id].clear() - # directly use count idle_predict_num += 1 assert idle_predict_num <= require_num, \ 'idle_predict_num={} must <= require_num={}'.format( @@ -359,8 +349,11 @@ def add_client(self, server_item, feeds, fetchs, conf_file): def stop_client(self, server_item): server = server_item.server - client = self._server_to_clients[server] - client.need_stop = True + item_client = self._server_to_clients.get(server) + # client may removed before this + if item_client is not None: + _, client = item_client + client.need_stop = True def rm_client(self, client): server = client.server @@ -470,12 +463,12 @@ def server_manager(): server_result_queue.put(None) return - if server_item.state == ServerItem.PENDING: + if server_item.state == ServerItem.RUNNING: if not client_pool.add_client(server_item, feeds, fetchs, conf_file): server_item.state = ServerItem.FINISHED server_result_queue.put(server_item) - elif server_item.state == ServerItem.STOP: + elif server_item.state == ServerItem.STOPPING: client_pool.stop_client(server_item) manage_thread = threading.Thread(target=server_manager, ) From c310bff369e61b14a387ac3800f25432f4b83a1c Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 19 Jun 2020 17:31:17 +0800 Subject: [PATCH 07/10] clean code --- python/paddle_edl/distill/distill_reader.py | 11 ----------- python/paddle_edl/distill/distill_worker.py | 12 +++++------- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/python/paddle_edl/distill/distill_reader.py b/python/paddle_edl/distill/distill_reader.py index c780d971..ebd831f4 100644 --- a/python/paddle_edl/distill/distill_reader.py +++ b/python/paddle_edl/distill/distill_reader.py @@ -92,11 +92,7 @@ def __init__(self, ins, predicts): # predict worker args self._predict_server_queue = None self._predict_server_result_queue = None - self._working_predict_count = None self._predict_out_queue = None - self._predict_stop_events = None - self._predict_lock = None - self._predict_finished_task = None self._predict_cond = None # predict worker pool self._predict_manage_thread = None @@ -175,7 +171,6 @@ def _start_predict_process(self): self._predict_server_queue, self._predict_server_result_queue, self._require_num, - self._predict_stop_events, self._get_servers, self._predict_manage_stop_event, self._predict_cond, )) @@ -202,13 +197,7 @@ def _init_args(self): # predict self._predict_server_queue = mps.Queue(self._require_num) self._predict_server_result_queue = mps.Queue(self._require_num) - self._working_predict_count = mps.Value('i', 0, lock=False) self._predict_out_queue = mps.Queue() - self._predict_stop_events = [ - mps.Event() for i in range(self._require_num) - ] - self._predict_lock = mps.Lock() - self._predict_finished_task = mps.Value('i', 0, lock=False) self._predict_cond = mps.Condition() # fetch diff --git a/python/paddle_edl/distill/distill_worker.py b/python/paddle_edl/distill/distill_worker.py index e294ac05..a6191913 100644 --- a/python/paddle_edl/distill/distill_worker.py +++ b/python/paddle_edl/distill/distill_worker.py @@ -58,8 +58,8 @@ def __init__(self, server_id, server, state=RUNNING): def predict_manage_worker(process, server_queue, server_result_queue, - require_num, predict_stop_events, get_servers_fun, - stop_event, predict_cond): + require_num, get_servers_fun, stop_event, + predict_cond): """ thread that manage predict worker """ num_shutdown_process = [0] @@ -138,8 +138,6 @@ def clean_queue(data_queue): clean_queue(server_result_queue) with predict_cond: - for predict_stop_event in predict_stop_events: - predict_stop_event.set() predict_cond.notify_all() for i in range(require_num): @@ -378,8 +376,7 @@ def run(self, in_queue, out_queue): poison_pill = data if finished_task_count[0] == poison_pill.feed_count: poison_pill.predict_count = poison_pill.feed_count - out_queue.put(poison_pill) - break # all task finished + return poison_pill # all task finished in_queue.put(poison_pill) # write back poison pill continue # continue process failed task @@ -476,8 +473,9 @@ def server_manager(): manage_thread.start() while not stop_event.is_set(): - client_pool.run(in_queue, out_queue) + poison_pill = client_pool.run(in_queue, out_queue) with predict_cond: + out_queue.put(poison_pill) predict_cond.wait() manager_need_stop = True From 3acc5fd99138354134638394cf88233e5b816b41 Mon Sep 17 00:00:00 2001 From: WangXi Date: Fri, 19 Jun 2020 19:28:55 +0800 Subject: [PATCH 08/10] exit --- python/paddle_edl/distill/distill_reader.py | 15 ++++++---- python/paddle_edl/distill/distill_worker.py | 33 ++++----------------- 2 files changed, 15 insertions(+), 33 deletions(-) diff --git a/python/paddle_edl/distill/distill_reader.py b/python/paddle_edl/distill/distill_reader.py index ebd831f4..6c9838ee 100644 --- a/python/paddle_edl/distill/distill_reader.py +++ b/python/paddle_edl/distill/distill_reader.py @@ -90,10 +90,12 @@ def __init__(self, ins, predicts): self._task_semaphore = None # predict worker args + self._predict_worker = None self._predict_server_queue = None self._predict_server_result_queue = None self._predict_out_queue = None self._predict_cond = None + self._predict_stop_event = None # predict worker pool self._predict_manage_thread = None self._predict_manage_stop_event = None @@ -152,28 +154,27 @@ def _start_reader_worker(self): def _start_predict_process(self): if not self._is_predict_start: - stop_event = mps.Event() + self._predict_stop_event = mps.Event() predict_process = mps.Process( target=distill_worker.predict_process, args=(self._predict_server_queue, self._predict_server_result_queue, self._reader_out_queue, self._predict_out_queue, self._feeds, self._fetchs, self._serving_conf_file, - stop_event, self._predict_cond)) + self._predict_stop_event, self._predict_cond)) predict_process.daemon = True predict_process.start() + self._predict_worker = predict_process self._predict_manage_stop_event = threading.Event() self._predict_manage_thread = threading.Thread( target=distill_worker.predict_manage_worker, args=( - [predict_process], self._predict_server_queue, self._predict_server_result_queue, self._require_num, self._get_servers, - self._predict_manage_stop_event, - self._predict_cond, )) + self._predict_manage_stop_event, )) self._predict_manage_thread.daemon = True self._predict_manage_thread.start() @@ -359,6 +360,10 @@ def __del__(self): self._predict_manage_stop_event.set() + self._predict_stop_event.set() + with self._predict_cond: + self._predict_cond.notify() + for i in range(20): if self._reader_worker.is_alive() or \ self._predict_manage_thread.is_alive(): diff --git a/python/paddle_edl/distill/distill_worker.py b/python/paddle_edl/distill/distill_worker.py index a6191913..f0e7b455 100644 --- a/python/paddle_edl/distill/distill_worker.py +++ b/python/paddle_edl/distill/distill_worker.py @@ -57,16 +57,9 @@ def __init__(self, server_id, server, state=RUNNING): self.state = state -def predict_manage_worker(process, server_queue, server_result_queue, - require_num, get_servers_fun, stop_event, - predict_cond): +def predict_manage_worker(server_queue, server_result_queue, require_num, + get_servers_fun, stop_event): """ thread that manage predict worker """ - num_shutdown_process = [0] - - def shutdown_one_process(): - server_queue.put(None) - num_shutdown_process[0] += 1 - server_id = 0 # not yet used server_to_item = dict() # server to server_item idle_predict_num = require_num @@ -137,22 +130,6 @@ def clean_queue(data_queue): clean_queue(server_queue) clean_queue(server_result_queue) - with predict_cond: - predict_cond.notify_all() - - for i in range(require_num): - shutdown_one_process() - clean_queue(server_result_queue) - - for i in range(20): - shutdown_process = 0 - for p in process: - if not p.is_alive(): - shutdown_process += 1 - if shutdown_process == len(process): - break - time.sleep(1) - class _PoisonPill: def __init__(self, feed_count, predict_count=0): @@ -447,10 +424,10 @@ def predict_signal_handle(signum, frame): try: client_pool = PredictPool(server_result_queue) - manager_need_stop = False + manager_need_stop = threading.Event() def server_manager(): - while not manager_need_stop: + while not manager_need_stop.is_set(): try: server_item = server_queue.get(timeout=1) except queue.Empty: @@ -478,7 +455,7 @@ def server_manager(): out_queue.put(poison_pill) predict_cond.wait() - manager_need_stop = True + manager_need_stop.set() manage_thread.join() except Exception as e: if signal_exit[0] is True: From 039bb8d7884125ef37bdc16bf64c8efeca387bdf Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 23 Jun 2020 10:27:16 +0800 Subject: [PATCH 09/10] update --- python/paddle_edl/distill/distill_worker.py | 138 ++++++++++++++------ 1 file changed, 99 insertions(+), 39 deletions(-) diff --git a/python/paddle_edl/distill/distill_worker.py b/python/paddle_edl/distill/distill_worker.py index f0e7b455..8ce60bc8 100644 --- a/python/paddle_edl/distill/distill_worker.py +++ b/python/paddle_edl/distill/distill_worker.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import gc import logging import numpy as np import os @@ -22,8 +23,8 @@ import time import threading -# from paddle_serving_client import Client, MultiLangClient -from paddle_serving_client import Client +from paddle_serving_client import MultiLangClient +#from paddle_serving_client import Client from six.moves import queue from six.moves import reduce from .timeline import _TimeLine @@ -179,7 +180,7 @@ def connect(self): """ connect success, return True, else return False""" try: client = MultiLangClient() - client.load_client_config(self._config_file) + #client.load_client_config(self._config_file) client.connect([self.server]) self.client = client except Exception as e: @@ -260,6 +261,11 @@ def result(self, future, feed_data): if fetch_map_list is None: return False, None + if fetch_map_list['status_code'] != 0: + logger.warning('Failed status code={}'.format(fetch_map_list[ + 'status_code'])) + return False, None + predict_data = self._postprocess(fetch_map_list, len(feed_data)) return True, predict_data @@ -298,6 +304,7 @@ def result(self, future, feed_data): class PredictPool(object): def __init__(self, server_result_queue, max_concurrent=3): self._clients = queue.PriorityQueue() + #self._clients = queue.Queue() self._server_to_clients = dict() self._client_num_lock = threading.Lock() @@ -306,6 +313,10 @@ def __init__(self, server_result_queue, max_concurrent=3): self._server_result_queue = server_result_queue + self.finished_task_count_lock = threading.Lock() + self.running_task_count_lock = threading.Lock() + self.task_cond = threading.Condition() + def add_client(self, server_item, feeds, fetchs, conf_file): server = server_item.server if server_item.server in self._server_to_clients: @@ -341,23 +352,43 @@ def rm_client(self, client): self._server_result_queue.put(server_item) def run(self, in_queue, out_queue): - finished_task_count_lock = threading.Lock() + finished_task_count_lock = self.finished_task_count_lock finished_task_count = [0] - running_task_count_lock = threading.Lock() + running_task_count_lock = self.running_task_count_lock running_task_count = [0] - task_cond = threading.Condition() + task_cond = self.task_cond while True: data = in_queue.get() if isinstance(data, _PoisonPill): poison_pill = data - if finished_task_count[0] == poison_pill.feed_count: + task_count = -1 + with finished_task_count_lock: + task_count = finished_task_count[0] + if task_count == poison_pill.feed_count: poison_pill.predict_count = poison_pill.feed_count return poison_pill # all task finished - in_queue.put(poison_pill) # write back poison pill + #time.sleep(0.005) + #logger.info('put poison pill back') + in_queue.put(data) # write back poison pill + #logger.info('put poison pill back ok') continue # continue process failed task + #if isinstance(data, _PoisonPill): + # poison_pill = data + # logger.info('ending ------------') + # all_count = finished_task_count[0] + # while all_count < poison_pill.feed_count: + # time.sleep(0.005) + # logger.info('wait 0.005') + # all_count = finished_task_count[0] + + # logger.info('ended ------------') + # assert finished_task_count[0] == poison_pill.feed_count + # poison_pill.predict_count = poison_pill.feed_count + # return poison_pill + task, read_data = data while True: client = self._clients.get() @@ -373,40 +404,66 @@ def run(self, in_queue, out_queue): with task_cond: task_cond.wait() - def predict_call_back(call_future, _predict_client, _in_data): - with running_task_count_lock: - running_task_count[0] -= 1 - with task_cond: - task_cond.notify() - - _task, _read_data = _in_data - success, predict_data = _predict_client.result(call_future, - _read_data) - if not success: - in_queue.put(_in_data) - _predict_client.need_stop = True - return - - out_data = _read_data - for i in range(len(out_data)): - out_data[i] += predict_data[i] - out_queue.put((_task, out_data)) - with finished_task_count_lock: - finished_task_count[0] += 1 - - future = client.predict(read_data) - future.add_done_callback( - functools.partial( - predict_call_back, - _predict_client=client, - _in_data=data)) + try: + future = client.predict(read_data) + except: + logger.info('?????????????') + + call_back = predict_call_back( + in_queue, out_queue, client, data, running_task_count_lock, + running_task_count, finished_task_count_lock, + finished_task_count, task_cond) + future.add_done_callback(call_back) self._clients.put(client) + #logger.info('garbage collector output is {}'.format(gc.get_stats())) + #gc.collect(0) break +def predict_call_back( + in_queue, + out_queue, + client, + data, + running_task_count_lock, + running_task_count, + finished_task_count_lock, + finished_task_count, + task_cond, ): + def _call_back(future): + with running_task_count_lock: + running_task_count[0] -= 1 + with task_cond: + task_cond.notify() + + task, read_data = data + success = False + try: + success, predict_data = client.result(future, read_data) + except Exception as e: + print('result error={}'.format(e)) + if not success: + try: + in_queue.put(data) + except Exception as e: + logger.info('failed inque error={}'.format(e)) + client.need_stop = True + return + + out_data = read_data + for i in range(len(out_data)): + out_data[i] += predict_data[i] + out_queue.put((task, out_data)) + with finished_task_count_lock: + finished_task_count[0] += 1 + + return _call_back + + def predict_process(server_queue, server_result_queue, in_queue, out_queue, feeds, fetchs, conf_file, stop_event, predict_cond): + logger.info('predict process pid={}'.format(os.getpid())) signal_exit = [False, ] # Define signal handler function @@ -449,6 +506,8 @@ def server_manager(): manage_thread.daemon = True manage_thread.start() + #gc.set_debug(gc.DEBUG_LEAK) + while not stop_event.is_set(): poison_pill = client_pool.run(in_queue, out_queue) with predict_cond: @@ -458,10 +517,11 @@ def server_manager(): manager_need_stop.set() manage_thread.join() except Exception as e: - if signal_exit[0] is True: - pass - else: - six.reraise(*sys.exc_info()) + #if signal_exit[0] is True: + # pass + #else: + print('error={}'.format(e)) + six.reraise(*sys.exc_info()) class ReaderType(object): From 93d2c55829d7406c95b6e566161b29234c580b43 Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 23 Jun 2020 11:30:31 +0800 Subject: [PATCH 10/10] update, multi client --- python/paddle_edl/distill/distill_worker.py | 153 ++++++-------------- 1 file changed, 47 insertions(+), 106 deletions(-) diff --git a/python/paddle_edl/distill/distill_worker.py b/python/paddle_edl/distill/distill_worker.py index 8ce60bc8..e8c4a9ca 100644 --- a/python/paddle_edl/distill/distill_worker.py +++ b/python/paddle_edl/distill/distill_worker.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools import gc import logging import numpy as np @@ -24,7 +23,7 @@ import threading from paddle_serving_client import MultiLangClient -#from paddle_serving_client import Client +# from paddle_serving_client import Client from six.moves import queue from six.moves import reduce from .timeline import _TimeLine @@ -156,7 +155,7 @@ def predict(self, feed_data): class AsyncPredictClient(object): - def __init__(self, server, config_file, feeds, fetchs, max_failed_times=3): + def __init__(self, server, config_file, feeds, fetchs): self.server = server self._config_file = config_file self._predict_feed_idxs = [] @@ -164,17 +163,11 @@ def __init__(self, server, config_file, feeds, fetchs, max_failed_times=3): self._predict_feed_size = dict() self._feeds = feeds self._fetchs = fetchs - self._max_failed_times = max_failed_times self.client = None self._has_predict = False self.need_stop = False - self.lock = None - self.future_count = 0 - logger.info((server, config_file, feeds, fetchs, max_failed_times)) - - def __lt__(self, other): - return self.future_count < other.future_count + logger.info((server, config_file, feeds, fetchs)) def connect(self): """ connect success, return True, else return False""" @@ -199,7 +192,6 @@ def connect(self): self.client.feed_shapes_[feed_name]) self._predict_feed_size[feed_name] = reduce( lambda x, y: x * y, self._predict_feed_shapes[feed_name]) - self.lock = threading.Lock() return True def _preprocess(self, feed_data): @@ -243,8 +235,6 @@ def predict(self, feed_data): feed_map_list = self._preprocess(feed_data) future = self.client.predict( feed=feed_map_list, fetch=self._fetchs, asyn=True) - with self.lock: - self.future_count += 1 return future def result(self, future, feed_data): @@ -255,9 +245,6 @@ def result(self, future, feed_data): logger.warning('Failed with server={}'.format(self.server)) logger.warning('Exception:\n{}'.format(str(e))) - with self.lock: - self.future_count -= 1 - if fetch_map_list is None: return False, None @@ -302,24 +289,15 @@ def result(self, future, feed_data): class PredictPool(object): - def __init__(self, server_result_queue, max_concurrent=3): - self._clients = queue.PriorityQueue() - #self._clients = queue.Queue() + def __init__(self): + self._client_queue = queue.Queue() self._server_to_clients = dict() - self._client_num_lock = threading.Lock() - self._client_num = 0 - self._max_concurrent = max_concurrent - - self._server_result_queue = server_result_queue - - self.finished_task_count_lock = threading.Lock() - self.running_task_count_lock = threading.Lock() - self.task_cond = threading.Condition() - - def add_client(self, server_item, feeds, fetchs, conf_file): + def add_client(self, server_item, feeds, fetchs, conf_file, concurrent=3): server = server_item.server if server_item.server in self._server_to_clients: + logger.warning('server={} in predict client?'.format( + server_item.server)) return True predict_server = AsyncPredictClient if _NOP_PREDICT_TEST is False else _TestNopAsyncPredictClient @@ -328,9 +306,8 @@ def add_client(self, server_item, feeds, fetchs, conf_file): return False self._server_to_clients[server] = (server_item, client) - self._clients.put(client) - with self._client_num_lock: - self._client_num += 1 + for _ in range(concurrent): + self._client_queue.put(client) return True def stop_client(self, server_item): @@ -341,81 +318,49 @@ def stop_client(self, server_item): _, client = item_client client.need_stop = True - def rm_client(self, client): + def rm_client(self, client, server_result_queue): server = client.server - server_item, client = self._server_to_clients[server] + item_client = self._server_to_clients.get(server) + # client already removed + if item_client is None: + return + server_item, client = item_client del self._server_to_clients[server] - with self._client_num_lock: - self._client_num -= 1 server_item.state = ServerItem.FINISHED - self._server_result_queue.put(server_item) + server_result_queue.put(server_item) - def run(self, in_queue, out_queue): - finished_task_count_lock = self.finished_task_count_lock + def run(self, in_queue, out_queue, server_result_queue): + finished_task_count_lock = threading.Lock() finished_task_count = [0] - running_task_count_lock = self.running_task_count_lock - running_task_count = [0] - task_cond = self.task_cond while True: data = in_queue.get() if isinstance(data, _PoisonPill): poison_pill = data - task_count = -1 - with finished_task_count_lock: - task_count = finished_task_count[0] - if task_count == poison_pill.feed_count: + if finished_task_count[0] == poison_pill.feed_count: poison_pill.predict_count = poison_pill.feed_count return poison_pill # all task finished - #time.sleep(0.005) - #logger.info('put poison pill back') in_queue.put(data) # write back poison pill - #logger.info('put poison pill back ok') + time.sleep(0.003) # wait 3ms continue # continue process failed task - #if isinstance(data, _PoisonPill): - # poison_pill = data - # logger.info('ending ------------') - # all_count = finished_task_count[0] - # while all_count < poison_pill.feed_count: - # time.sleep(0.005) - # logger.info('wait 0.005') - # all_count = finished_task_count[0] - - # logger.info('ended ------------') - # assert finished_task_count[0] == poison_pill.feed_count - # poison_pill.predict_count = poison_pill.feed_count - # return poison_pill - task, read_data = data while True: - client = self._clients.get() + client = self._client_queue.get() if client.need_stop: - self.rm_client(client) + self.rm_client(client, server_result_queue) continue - with running_task_count_lock: - running_task_count[0] += 1 - # limit max concurrent of client - while running_task_count[0] > self._client_num * \ - self._max_concurrent: - with task_cond: - task_cond.wait() + # FIXME. may failed + future = client.predict(read_data) - try: - future = client.predict(read_data) - except: - logger.info('?????????????') - - call_back = predict_call_back( - in_queue, out_queue, client, data, running_task_count_lock, - running_task_count, finished_task_count_lock, - finished_task_count, task_cond) + call_back = predict_call_back(in_queue, out_queue, client, + data, finished_task_count_lock, + finished_task_count, + self._client_queue) future.add_done_callback(call_back) - - self._clients.put(client) #logger.info('garbage collector output is {}'.format(gc.get_stats())) #gc.collect(0) break @@ -426,34 +371,29 @@ def predict_call_back( out_queue, client, data, - running_task_count_lock, - running_task_count, finished_task_count_lock, finished_task_count, - task_cond, ): + client_queue, ): def _call_back(future): - with running_task_count_lock: - running_task_count[0] -= 1 - with task_cond: - task_cond.notify() - task, read_data = data + batch_size = len(read_data) success = False try: success, predict_data = client.result(future, read_data) except Exception as e: - print('result error={}'.format(e)) + logger.info('predict error={}'.format(e)) if not success: - try: - in_queue.put(data) - except Exception as e: - logger.info('failed inque error={}'.format(e)) + in_queue.put(data) client.need_stop = True + + client_queue.put(client) # complete, put back + if not success: return out_data = read_data - for i in range(len(out_data)): + for i in range(batch_size): out_data[i] += predict_data[i] + out_queue.put((task, out_data)) with finished_task_count_lock: finished_task_count[0] += 1 @@ -480,13 +420,13 @@ def predict_signal_handle(signum, frame): signal.signal(signal.SIGTERM, predict_signal_handle) try: - client_pool = PredictPool(server_result_queue) + client_pool = PredictPool() manager_need_stop = threading.Event() def server_manager(): while not manager_need_stop.is_set(): try: - server_item = server_queue.get(timeout=1) + server_item = server_queue.get(timeout=2) except queue.Empty: continue @@ -509,7 +449,8 @@ def server_manager(): #gc.set_debug(gc.DEBUG_LEAK) while not stop_event.is_set(): - poison_pill = client_pool.run(in_queue, out_queue) + poison_pill = client_pool.run(in_queue, out_queue, + server_result_queue) with predict_cond: out_queue.put(poison_pill) predict_cond.wait() @@ -517,11 +458,11 @@ def server_manager(): manager_need_stop.set() manage_thread.join() except Exception as e: - #if signal_exit[0] is True: - # pass - #else: - print('error={}'.format(e)) - six.reraise(*sys.exc_info()) + if signal_exit[0] is True: + pass + else: + print('error={}'.format(e)) + six.reraise(*sys.exc_info()) class ReaderType(object):