WIP: master: several fixes for batch async code #693

Open
wants to merge 11 commits into base: openSUSE/devel/master
152 changes: 68 additions & 84 deletions salt/cli/batch_async.py
@@ -7,6 +7,7 @@
import re

import tornado
import asyncio

import salt.client
import salt.utils.event
@@ -67,8 +68,8 @@ def __init__(self, opts, io_loop):
keep_loop=True,
)
self.master_event.set_event_handler(self.__handle_event)
if self.master_event.subscriber.stream:
self.master_event.subscriber.stream.set_close_callback(self.__handle_close)
if self.master_event.subscriber._stream:
self.master_event.subscriber._stream.set_close_callback(self.__handle_close)
self._re_tag_ret_event = re.compile(r"salt\/job\/(\d+)\/ret\/.*")
self._subscribers = {}
self._subscriptions = {}
@@ -92,7 +93,7 @@ def subscribe(self, jid, op, subscriber_id, handler):
self._subscribers[subscriber_id].add(jid)
if (op, subscriber_id, handler) not in self._subscriptions[jid]:
self._subscriptions[jid].append((op, subscriber_id, handler))
if not self.master_event.subscriber.connected():
if not self.master_event.subscriber.connected:
self.__reconnect_subscriber()

def unsubscribe(self, jid, op, subscriber_id):
@@ -113,25 +114,24 @@ def unsubscribe(self, jid, op, subscriber_id):
if not self._subscribers[subscriber_id]:
del self._subscribers[subscriber_id]

@tornado.gen.coroutine
def __handle_close(self):
async def __handle_close(self):
if not self._subscriptions:
return
log.warning("Master Event Subscriber was closed. Trying to reconnect...")
yield self.__reconnect_subscriber()
await self.__reconnect_subscriber()

@tornado.gen.coroutine
def __handle_event(self, raw):
async def __handle_event(self, raw):
if self.master_event is None:
return
try:
tag, data = self.master_event.unpack(raw)
tag_match = self._re_tag_ret_event.match(tag)
log.trace("SharedEventsChannel.__handle_event -> {} - {}".format(tag, data))
if tag_match:
jid = tag_match.group(1)
if jid in self._subscriptions:
for op, _, handler in self._subscriptions[jid]:
yield handler(tag, data, op)
await handler(tag, data, op)
except Exception as ex: # pylint: disable=W0703
log.error(
"Exception occured while processing event: %s: %s",
Expand All @@ -140,9 +140,8 @@ def __handle_event(self, raw):
exc_info=True,
)

@tornado.gen.coroutine
def __reconnect_subscriber(self):
if self.master_event.subscriber.connected() or self._reconnecting_subscriber:
async def __reconnect_subscriber(self):
if self.master_event.subscriber.connected or self._reconnecting_subscriber:
return
self._reconnecting_subscriber = True
max_tries = max(1, int(self._subscriber_reconnect_tries))
@@ -154,22 +153,22 @@ def __reconnect_subscriber(self):
max_tries,
)
try:
yield self.master_event.subscriber.connect()
await self.master_event.subscriber.connect()
except StreamClosedError:
log.warning(
"Unable to reconnect to event publisher (try %d of %d)",
_try,
max_tries,
)
if self.master_event.subscriber.connected():
self.master_event.subscriber.stream.set_close_callback(
if self.master_event.subscriber.connected:
self.master_event.subscriber._stream.set_close_callback(
self.__handle_close
)
log.info("Event publisher connection restored")
self._reconnecting_subscriber = False
return
if _try < max_tries:
yield tornado.gen.sleep(self._subscriber_reconnect_interval)
await asyncio.sleep(self._subscriber_reconnect_interval)
_try += 1
self._reconnecting_subscriber = False

@@ -181,9 +180,9 @@ def unuse(self, subscriber_id):
self._used_by.discard(subscriber_id)

def destroy_unused(self):
log.trace("SharedEventsChannel.destroy_unused called")
if self._used_by:
return False
self.master_event.remove_event_handler(self.__handle_event)
self.master_event.destroy()
self.master_event = None
self.local_client.destroy()
@@ -278,24 +277,25 @@ def __set_event_handler(self):
self.batch_jid, "batch_run", id(self), self.__event_handler
)

@tornado.gen.coroutine
def __event_handler(self, tag, data, op):
async def __event_handler(self, tag, data, op):
# IMPORTANT: This function must run fast and not wait for any other task,
# otherwise it would cause events to be stuck.
log.trace("BatchAsync.__event_handler called with ({}, {}, {})".format(op, tag, data))
if not self.event:
return
try:
minion = data["id"]
if op == "ping_return":
self.minions.add(minion)
if self.targeted_minions == self.minions:
yield self.start_batch()
elif op == "find_job_return":
if data.get("return", None):
self.find_job_returned.add(minion)
elif op == "batch_run":
if minion in self.active:
self.active.remove(minion)
self.done_minions.add(minion)
yield self.schedule_next()
if not self.active:
asyncio.create_task(self.schedule_next())
except Exception as ex: # pylint: disable=W0703
log.error(
"Exception occured while processing event: %s: %s",
@@ -316,8 +316,7 @@ def _get_next(self):
)
return set(list(to_run)[:next_batch_size])

@tornado.gen.coroutine
def check_find_job(self, batch_minions, jid):
async def check_find_job(self, batch_minions, jid):
"""
Check if the job with specified ``jid`` was finished on the minions
"""
@@ -335,17 +334,17 @@ def check_find_job(self, batch_minions, jid):
)

if timedout_minions:
yield self.schedule_next()
asyncio.create_task(self.schedule_next())

if self.event and running:
self.find_job_returned = self.find_job_returned.difference(running)
yield self.find_job(running)
await self.find_job(running)

@tornado.gen.coroutine
def find_job(self, minions):
async def find_job(self, minions):
"""
Find if the job was finished on the minions
"""
log.trace("BatchAsync.find_job called for minions: {}".format(minions))
if not self.event:
return
not_done = minions.difference(self.done_minions).difference(
@@ -358,7 +357,7 @@ def find_job(self, minions):
self.events_channel.subscribe(
jid, "find_job_return", id(self), self.__event_handler
)
ret = yield self.events_channel.local_client.run_job_async(
await self.events_channel.local_client.run_job_async(
not_done,
"saltutil.find_job",
[self.batch_jid],
@@ -369,9 +368,9 @@ def find_job(self, minions):
listen=False,
**self.eauth,
)
yield tornado.gen.sleep(self.opts["gather_job_timeout"])
await asyncio.sleep(self.opts["gather_job_timeout"])
if self.event:
yield self.check_find_job(not_done, jid)
await self.check_find_job(not_done, jid)
except Exception as ex: # pylint: disable=W0703
log.error(
"Exception occured handling batch async: %s. Aborting execution.",
@@ -380,15 +379,14 @@ def find_job(self, minions):
)
self.close_safe()

@tornado.gen.coroutine
def start(self):
async def start(self):
"""
Start the batch execution
"""
if not self.event:
return
self.__set_event_handler()
ping_return = yield self.events_channel.local_client.run_job_async(
ping_return = await self.events_channel.local_client.run_job_async(
self.opts["tgt"],
"test.ping",
[],
@@ -402,40 +400,22 @@ def start(self):
)
self.targeted_minions = set(ping_return["minions"])
# start batching even if not all minions respond to ping
yield tornado.gen.sleep(
self.batch_presence_ping_timeout or self.opts["gather_job_timeout"]
)
if self.event:
yield self.start_batch()
try:
async with asyncio.timeout(
self.batch_presence_ping_timeout or self.opts["gather_job_timeout"]
):
while True:
await asyncio.sleep(0.03)
if self.targeted_minions == self.minions:
break
except TimeoutError:
# Some minions are down, scheduling batch anyway
pass

@tornado.gen.coroutine
def start(self):
"""
Start the batch execution
"""
if not self.event:
return
self.__set_event_handler()
ping_return = yield self.local.run_job_async(
self.opts["tgt"],
"test.ping",
[],
self.opts.get("selected_target_option", self.opts.get("tgt_type", "glob")),
gather_job_timeout=self.opts["gather_job_timeout"],
jid=self.ping_jid,
metadata=self.metadata,
**self.eauth,
)
self.targeted_minions = set(ping_return["minions"])
# start batching even if not all minions respond to ping
yield tornado.gen.sleep(
self.batch_presence_ping_timeout or self.opts["gather_job_timeout"]
)
if self.event:
self.event.io_loop.spawn_callback(self.start_batch)
await self.start_batch()

@tornado.gen.coroutine
def start_batch(self):
async def start_batch(self):
"""
Fire `salt/batch/*/start` and continue batch with `run_next`
"""
@@ -448,17 +428,17 @@ def start_batch(self):
"down_minions": self.targeted_minions.difference(self.minions),
"metadata": self.metadata,
}
yield self.events_channel.master_event.fire_event_async(
data, f"salt/batch/{self.batch_jid}/start"
ret = self.event.fire_event(
data, "salt/batch/{}/start".format(self.batch_jid)
)
if self.event:
yield self.run_next()
await self.run_next()

@tornado.gen.coroutine
def end_batch(self):
async def end_batch(self):
"""
End the batch and call safe closing
"""
log.trace("BatchAsync.end_batch called")
left = self.minions.symmetric_difference(
self.done_minions.union(self.timedout_minions)
)
@@ -474,46 +454,50 @@ def end_batch(self):
"timedout_minions": self.timedout_minions,
"metadata": self.metadata,
}
yield self.events_channel.master_event.fire_event_async(
ret = self.event.fire_event(
data, f"salt/batch/{self.batch_jid}/done"
)

# release to the IOLoop to allow the event to be published
# before closing batch async execution
yield tornado.gen.sleep(1)
await asyncio.sleep(0.03)
self.close_safe()

def close_safe(self):
log.trace("BatchAsync.close_safe called")
if self.events_channel is not None:
self.events_channel.unsubscribe(None, None, id(self))
self.events_channel.unuse(id(self))
self.events_channel = None
_destroy_unused_shared_events_channel()
self.event = None

@tornado.gen.coroutine
def schedule_next(self):
async def schedule_next(self):
log.trace("BatchAsync.schedule_next called")
if self.scheduled:
log.trace("BatchAsync.schedule_next -> Batch already scheduled, nothing to do.")
return
self.scheduled = True
# call later so that we maybe gather more returns
yield tornado.gen.sleep(self.batch_delay)
if self._get_next():
# call later so that we maybe gather more returns
log.trace("BatchAsync.schedule_next delaying batch {} second(s).".format(self.batch_delay))
await asyncio.sleep(self.batch_delay)
if self.event:
yield self.run_next()
await self.run_next()

@tornado.gen.coroutine
def run_next(self):
async def run_next(self):
"""
Continue batch execution with the next targets
"""
self.scheduled = False
next_batch = self._get_next()
log.trace("BatchAsync.run_next called. Next Batch -> {}".format(next_batch))
if not next_batch:
yield self.end_batch()
await self.end_batch()
return
self.active = self.active.union(next_batch)
try:
ret = yield self.events_channel.local_client.run_job_async(
await self.events_channel.local_client.run_job_async(
next_batch,
self.opts["fun"],
self.opts["arg"],
@@ -529,11 +513,11 @@ def run_next(self):
**self.extra_job_kwargs,
)

yield tornado.gen.sleep(self.opts["timeout"])
await asyncio.sleep(self.opts["timeout"])

# The batch can be done already at this point, which means no self.event
if self.event:
yield self.find_job(set(next_batch))
if self.event and self.active.intersection(next_batch):
await self.find_job(set(next_batch))
except Exception as ex: # pylint: disable=W0703
log.error(
"Error in scheduling next batch: %s. Aborting execution",
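The thread running through the batch_async.py changes above: Tornado's decorator-based coroutines (@tornado.gen.coroutine with yield) become native async def / await, slow follow-ups are pushed into asyncio.create_task() so the event handler itself returns quickly, and the presence-ping wait in start() is bounded with asyncio.timeout(). The sketch below illustrates those three patterns in isolation; the names (handle_event, schedule_next, wait_for_minions) are illustrative rather than the Salt API, and asyncio.timeout() assumes Python 3.11+ (older interpreters would need asyncio.wait_for() instead).

    import asyncio

    async def schedule_next(delay):
        # Stand-in for BatchAsync.schedule_next(): pause so more returns
        # can arrive, then continue with the next batch.
        await asyncio.sleep(delay)
        print("running next batch")

    async def handle_event(tag, data, background_tasks):
        # Event handlers must stay fast: schedule slow work as a separate
        # task instead of awaiting it inline.
        print(f"event {tag}: {data}")
        task = asyncio.create_task(schedule_next(0.1))
        background_tasks.add(task)  # keep a reference so the task is not garbage-collected
        task.add_done_callback(background_tasks.discard)

    async def wait_for_minions(targeted, seen, timeout=2):
        # Bounded wait, as in start(): poll until every targeted minion
        # answered the ping, or give up and start the batch anyway.
        try:
            async with asyncio.timeout(timeout):
                while targeted != seen:
                    await asyncio.sleep(0.03)
        except TimeoutError:
            pass  # some minions are down; proceed with what we have

    async def main():
        tasks = set()
        await handle_event("salt/job/123/ret/minion1", {"id": "minion1"}, tasks)
        await wait_for_minions({"minion1", "minion2"}, {"minion1"})
        await asyncio.gather(*tasks)

    asyncio.run(main())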
2 changes: 1 addition & 1 deletion salt/transport/zeromq.py
@@ -1135,7 +1135,7 @@ async def _send_recv(self, message):
try:
await self.socket.send(message)
ret = await self.socket.recv()
except zmq.error.ZMQError:
except (zmq.error.ZMQError, AttributeError):
self.close()
await self.connect()
await self.socket.send(message)
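The one-line transport change widens the retry path in the async request client: besides zmq.error.ZMQError, an AttributeError (e.g. the socket already closed and set to None by another code path) now also triggers a reconnect and a single retry of the send/receive. A rough, self-contained sketch of that shape follows; it is not the actual Salt transport class:

    import zmq
    import zmq.asyncio

    class RequestClient:
        # Illustrative send/recv wrapper with a single reconnect-and-retry.

        def __init__(self, addr):
            self.addr = addr
            self.ctx = zmq.asyncio.Context.instance()
            self.socket = None

        async def connect(self):
            self.socket = self.ctx.socket(zmq.REQ)
            self.socket.connect(self.addr)

        def close(self):
            if self.socket is not None:
                self.socket.close(linger=0)
                self.socket = None

        async def _send_recv(self, message):
            try:
                await self.socket.send(message)
                return await self.socket.recv()
            except (zmq.error.ZMQError, AttributeError):
                # AttributeError covers self.socket being None because the
                # connection was already torn down elsewhere.
                self.close()
                await self.connect()
                await self.socket.send(message)
                return await self.socket.recv()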
9 changes: 0 additions & 9 deletions salt/utils/event.py
@@ -924,15 +924,6 @@ def fire_ret_load(self, load):
# Minion fired a bad retcode, fire an event
self._fire_ret_load_specific_fun(load)

def remove_event_handler(self, event_handler):
"""
Remove the event_handler callback

.. versionadded:: 3007.0
"""
if event_handler in self.subscriber.callbacks:
self.subscriber.callbacks.remove(event_handler)

def set_event_handler(self, event_handler):
"""
Invoke the event_handler callback each time an event arrives.
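The deleted remove_event_handler() helper (added in 3007.0) did nothing more than drop a callback from the subscriber's callback list. For reference, a tiny stand-alone sketch of that register/deregister pattern; MiniSubscriber and its methods are purely illustrative, not the SaltEvent API:

    import asyncio

    class MiniSubscriber:
        # Illustrative callback registry, not the SaltEvent/Subscriber API.

        def __init__(self):
            self.callbacks = []

        def set_event_handler(self, event_handler):
            # Invoke event_handler for every event that arrives.
            if event_handler not in self.callbacks:
                self.callbacks.append(event_handler)

        def remove_event_handler(self, event_handler):
            # What the deleted helper did: stop routing events to the handler.
            if event_handler in self.callbacks:
                self.callbacks.remove(event_handler)

        async def dispatch(self, raw):
            for callback in list(self.callbacks):
                await callback(raw)

    async def main():
        sub = MiniSubscriber()

        async def handler(raw):
            print("got", raw)

        sub.set_event_handler(handler)
        await sub.dispatch(b"event-1")  # handler fires
        sub.remove_event_handler(handler)
        await sub.dispatch(b"event-2")  # no handlers registered any more

    asyncio.run(main())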