Brpc multi client thread #123

Open · wants to merge 7 commits into base: develop
4 changes: 2 additions & 2 deletions python/paddle_edl/distill/distill_reader.py
@@ -211,7 +211,7 @@ def _init_args(self):
self._reader_out_queue = mps.Queue()
self._reader_stop_event = mps.Event()
self._reader_cond = mps.Condition()
self._task_semaphore = mps.Semaphore(2 * self._require_num + 2)
self._task_semaphore = mps.Semaphore(4 * self._require_num)

# predict
self._predict_server_queue = mps.Queue(self._require_num)
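
For context on the sizing change above: the task semaphore bounds how many read tasks can be in flight before the reader blocks, and moving from 2 * require_num + 2 to 4 * require_num gives the new multi-client predict pipeline more queued work per required server. A minimal sketch of the flow-control pattern, with illustrative names rather than the module's actual API:

    import multiprocessing as mp

    require_num = 2
    task_semaphore = mp.Semaphore(4 * require_num)

    def send_task(in_queue, task):
        task_semaphore.acquire()  # blocks once 4 * require_num tasks are pending
        in_queue.put(task)

    def on_task_consumed():
        task_semaphore.release()  # frees one slot for the reader side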
@@ -343,7 +343,7 @@ def print_config(self):
'teacher_service_name': self._service_name,
'reader_type': self._reader_type,
}
for config, value in print_config.iteritems():
for config, value in print_config.items():
print("%s: %s" % (config, value))
print("------------------------------------------------")

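The print_config fix is a Python 3 compatibility change: dict.iteritems() no longer exists on Python 3, while items() works on both (a list on 2, a view on 3, either is fine for a small config dict). Code that must stay 2/3-agnostic without building a list usually goes through six, which this package already depends on; a sketch, not part of this PR:

    import six

    print_config = {'teacher_service_name': 'svc', 'reader_type': 'sample'}
    # six.iteritems resolves to dict.iteritems on Python 2 and dict.items on 3.
    for config, value in six.iteritems(print_config):
        print("%s: %s" % (config, value))
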
205 changes: 160 additions & 45 deletions python/paddle_edl/distill/distill_worker.py
@@ -20,6 +20,8 @@
import sys
import time

from collections import deque
from concurrent import futures
from paddle_serving_client import Client
from six.moves import queue
from six.moves import reduce
@@ -185,7 +187,7 @@ def predict(self, feed_data):


class PaddlePredictServer(PredictServer):
def __init__(self, server, config_file, feeds, fetchs, max_failed_times=3):
def __init__(self, server, config_file, feeds, fetchs, max_failed_times=2):
self._server = server
self._config_file = config_file
self._predict_feed_idxs = []
@@ -293,14 +295,15 @@ def predict(self, feed_data):
def __del__(self):
try:
# avoid serving exit bug when hasn't predict
if self.client is not None and self._has_predict:
self.client.release()
#if self.client is not None and self._has_predict:
# self.client.release()
pass
except Exception as e:
logger.critical('Release client failed with server={}, '
'there may be an unknown error'.format(
self._server))
logger.critical('Exception:\n{}'.format(str(e)))
logger.warning('Stopped predict server={}'.format(self._server))
#logger.warning('Stopped predict server={}'.format(self._server))


class _TestNopPaddlePredictServer(PaddlePredictServer):
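
Context for the __del__ change: per the inline comment, calling client.release() during interpreter teardown can hit a paddle-serving exit bug, so the PR disables it and lets the connection leak until the process exits. If eager cleanup is ever needed, doing it outside __del__ is the safer shape; a hypothetical method, not an API this module currently exposes:

    def shutdown(self):
        # Hypothetical explicit-release hook: run while the interpreter is
        # still fully alive instead of relying on __del__ ordering at exit.
        if self.client is not None and self._has_predict:
            try:
                self.client.release()
            finally:
                self.client = None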
@@ -334,22 +337,25 @@ def predict_signal_handle(signum, frame):
signal.signal(signal.SIGTERM, predict_signal_handle)

try:
while True:
# get server
server_item = server_queue.get()
if server_item is None:
server_result_queue.put(None)
return

# predict
success = predict_loop(server_item, working_predict_count,
in_queue, out_queue, feeds, fetchs,
conf_file, stop_events, predict_lock,
global_finished_task, predict_cond)

server_item.state = ServerItem.FINISHED if success else ServerItem.ERROR
server_result_queue.put(server_item)
logger.info('Stopped server={}'.format(server_item.server))
max_concurrent = 3
with futures.ThreadPoolExecutor(max_concurrent) as thread_pool:
while True:
# get server
server_item = server_queue.get()
if server_item is None:
server_result_queue.put(None)
return

# predict
success = predict_loop(server_item, working_predict_count,
in_queue, out_queue, feeds, fetchs,
conf_file, stop_events, predict_lock,
global_finished_task, predict_cond,
thread_pool, max_concurrent)

server_item.state = ServerItem.FINISHED if success else ServerItem.ERROR
server_result_queue.put(server_item)
logger.info('Stopped server={}'.format(server_item.server))
except Exception as e:
if signal_exit[0] is True:
pass
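
This rewrite is the core of the PR: each predict worker now drives up to max_concurrent = 3 requests at once through a ThreadPoolExecutor, one brpc client per slot, instead of issuing a single synchronous request at a time. Stripped of the queue bookkeeping, the pattern is roughly the following (illustrative names; it assumes, as PaddlePredictServer.predict does here, that predict returns a (success, data) pair):

    from collections import deque
    from concurrent import futures

    def run_pipeline(clients, get_data, handle_result, max_concurrent=3):
        # At most max_concurrent predictions in flight; data is spread
        # round-robin across one client per pool thread.
        tasks = deque(maxlen=max_concurrent)
        idx = 0
        with futures.ThreadPoolExecutor(max_concurrent) as pool:
            while True:
                data = get_data()
                if data is None:  # poison pill: stop submitting
                    break
                if len(tasks) == max_concurrent:
                    handle_result(*tasks.popleft().result())  # wait on oldest
                tasks.append(pool.submit(clients[idx].predict, data))
                idx = (idx + 1) % max_concurrent
            while tasks:  # drain whatever is still in flight
                handle_result(*tasks.popleft().result())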
@@ -359,26 +365,119 @@ def predict_signal_handle(signum, frame):

def predict_loop(server_item, working_predict_count, in_queue, out_queue,
feeds, fetchs, conf_file, stop_events, predict_lock,
global_finished_task, predict_cond):
global_finished_task, predict_cond, thread_pool,
max_concurrent):
logger.info('connect server={}'.format(server_item.server))
predict_server = PaddlePredictServer if _NOP_PREDICT_TEST is False else _TestNopPaddlePredictServer
client = predict_server(server_item.server, conf_file, feeds, fetchs)
if not client.connect():
return False
idx = 0
clients = []
for _ in range(max_concurrent):
client = predict_server(server_item.server, conf_file, feeds, fetchs)
if not client.connect():
return False
clients.append(client)

tasks = deque(maxlen=max_concurrent)

stop_event = stop_events[server_item.stop_event_id]
with predict_lock:
working_predict_count.value += 1

time_line = _TimeLine()
finished_task = 0
delay = 0.0005 # 500us
# predict loop
while not stop_event.is_set():
data = in_queue.get()
if len(tasks) == max_concurrent:
# full, sync wait
task = tasks.popleft()
success, out_data = task.result()
if not success:
failed_datas = [out_data, ]
finished_task += process_remain_predict_data(
tasks, failed_datas, out_queue)

with predict_lock:
global_finished_task.value += finished_task
for failed_data in failed_datas:
in_queue.put(
failed_data) # write back failed task data
# last process
if working_predict_count.value == 1:
# NOTE. need notify other predict worker, or maybe deadlock
with predict_cond:
predict_cond.notify_all()
working_predict_count.value -= 1
return False

logger.debug('task_id={}'.format(out_data[0].task_id))
out_queue.put(out_data)
finished_task += 1
elif len(tasks) > 0:
# not full, query left
task = tasks.popleft()
if task.done():
success, out_data = task.result()
if not success:
failed_datas = [out_data, ]
finished_task += process_remain_predict_data(
tasks, failed_datas, out_queue)

with predict_lock:
global_finished_task.value += finished_task
for failed_data in failed_datas:
in_queue.put(
failed_data) # write back failed task data
# last process
if working_predict_count.value == 1:
# NOTE. need notify other predict worker, or maybe deadlock
with predict_cond:
predict_cond.notify_all()
working_predict_count.value -= 1
return False

logger.debug('task_id={}'.format(out_data[0].task_id))
out_queue.put(out_data)
finished_task += 1
else:
# not done, write back left
tasks.appendleft(task)

if len(tasks) == 0:
data = in_queue.get()
else:
# avoid hang
try:
data = in_queue.get(timeout=delay)
delay = 0.0005 # 500us
except queue.Empty:
delay = min(delay * 2, 0.032) # max 32ms
continue

time_line.record('get_data')

# Poison
if isinstance(data, _PoisonPill):
failed_datas = []
finished_task += process_remain_predict_data(tasks, failed_datas,
out_queue)

if len(failed_datas) != 0:
failed_datas.append(data)

with predict_lock:
global_finished_task.value += finished_task
for failed_data in failed_datas:
in_queue.put(
failed_data) # write back failed task data
# last process
if working_predict_count.value == 1:
# NOTE. need notify other predict worker, or maybe deadlock
with predict_cond:
predict_cond.notify_all()
working_predict_count.value -= 1
return False

poison_pill = data
all_worker_done = False
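
One detail worth noting in the loop above: while futures are still in flight, in_queue.get polls with a timeout that doubles from 500us to a 32ms cap on every empty read and resets once data arrives, so the worker keeps servicing finished predictions without spinning on an idle queue. The backoff idea by itself, as a self-contained sketch:

    from six.moves import queue

    def get_with_backoff(in_queue, first_delay=0.0005, max_delay=0.032):
        # Exponential backoff poll: 0.5ms, 1ms, 2ms, ... capped at 32ms.
        delay = first_delay
        while True:
            try:
                return in_queue.get(timeout=delay)
            except queue.Empty:
                delay = min(delay * 2, max_delay)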

@@ -430,24 +529,26 @@ def predict_loop(server_item, working_predict_count, in_queue, out_queue,
working_predict_count.value += 1
continue

success, out_data = client_predict(client, data)
time_line.record('predict')

if not success:
with predict_lock:
global_finished_task.value += finished_task
in_queue.put(data) # write back failed task data
# last process
if working_predict_count.value == 1:
# NOTE. need notify other predict worker, or maybe deadlock
with predict_cond:
predict_cond.notify_all()
working_predict_count.value -= 1
return False

out_queue.put(out_data)
finished_task += 1
time_line.record('put_data')
client = clients[idx]
idx = (idx + 1) % max_concurrent
future = thread_pool.submit(client_predict, client, data)
tasks.append(future)

failed_datas = []
finished_task += process_remain_predict_data(tasks, failed_datas,
out_queue)
if len(failed_datas) != 0:
with predict_lock:
global_finished_task.value += finished_task
for failed_data in failed_datas:
in_queue.put(failed_data) # write back failed task data
# last process
if working_predict_count.value == 1:
# NOTE. need notify other predict worker, or maybe deadlock
with predict_cond:
predict_cond.notify_all()
working_predict_count.value -= 1
return False

# disconnect with server
with predict_lock:
@@ -461,6 +562,20 @@ def predict_loop(server_item, working_predict_count, in_queue, out_queue,
return True


def process_remain_predict_data(tasks, failed_datas, out_queue):
finished_task = 0
remain_task_count = len(tasks)
for i in range(remain_task_count):
task = tasks.popleft()
success, out_data = task.result()
if not success:
failed_datas.append(out_data)
else:
out_queue.put(out_data)
finished_task += 1
return finished_task
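
This helper is the single drain point for every exit path above: it waits on each remaining future, forwards successes to out_queue, and collects failed payloads into failed_datas so the caller can write them back to in_queue for another worker. A toy harness, assuming the definition above is in scope:

    from collections import deque
    from concurrent import futures
    from six.moves import queue

    out_queue, in_queue = queue.Queue(), queue.Queue()
    with futures.ThreadPoolExecutor(1) as pool:
        tasks = deque([pool.submit(lambda: (True, 'ok')),
                       pool.submit(lambda: (False, 'payload'))])
        failed_datas = []
        done = process_remain_predict_data(tasks, failed_datas, out_queue)
        # done == 1, out_queue holds 'ok', and the failed payload goes back:
        for failed_data in failed_datas:
            in_queue.put(failed_data)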


def client_predict(client, data):
# read_data format e.g. [(img, label, img1, label1), (img, label, img1, label1)]
# predict_data format e.g. [(predict0, predict1), (predict0, predict1)]
Expand All @@ -470,7 +585,7 @@ def client_predict(client, data):
task, read_data = data
success, predict_data = client.predict(read_data)
if not success:
return False, None
return False, data

out_data = read_data
for i in range(len(out_data)):
@@ -621,7 +736,7 @@ def read_batch(reader, teacher_batch_size, out_queue, task_semaphore):
for i in range(batch_size):
slot_data = tuple()
for j in range(slot_size):
slot_data += (read_data[j][i], )
slot_data += (np.asarray(read_data[j][i]), )
send_data.append(slot_data)

sample_size += 1
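The read_batch change wraps every slot element in np.asarray before sending, so the teacher service always receives ndarrays even when a user reader yields plain lists or tuples; for data that is already an ndarray the call is a no-copy pass-through:

    import numpy as np

    a = np.asarray([1.0, 2.0])  # list -> new float64 ndarray
    b = np.asarray(a)           # already an ndarray: returned as-is
    assert b is a
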
1 change: 1 addition & 0 deletions python/requirements.txt
@@ -1,2 +1,3 @@
flask
pathlib2
futures; python_version == "2.7"
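
concurrent.futures is standard library from Python 3.2 on; the environment marker pulls in the futures backport only on 2.7, where it installs the same concurrent.futures package path. The import used by distill_worker.py therefore works unchanged on both interpreters:

    # Resolves to the stdlib on Python 3 and the 'futures' backport on 2.7.
    from concurrent import futures

    with futures.ThreadPoolExecutor(max_workers=2) as pool:
        print(pool.submit(lambda: 1 + 1).result())  # prints 2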
3 changes: 2 additions & 1 deletion python/setup.py.in
@@ -41,7 +41,8 @@ if os.getenv("PADDLE_EDL_VERSION"):
max_version, mid_version, min_version = python_version()

REQUIRED_PACKAGES = [
'six >= 1.10.0', 'protobuf >= 3.1.0', "flask", "pathlib2"
'six >= 1.10.0', 'protobuf >= 3.1.0', "flask", "pathlib2",
'futures; python_version == "2.7"'
]

packages=['paddle_edl',