diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bcc95b..5f45c85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Fixed +- make all agent threads thread-safe +- fix dict items iteration when doing reboot due to python3 changes ## [1.8.7] - 2019-11-21 ### Changed diff --git a/talker_agent/talker.py b/talker_agent/talker.py index 24dfc57..da0ae09 100644 --- a/talker_agent/talker.py +++ b/talker_agent/talker.py @@ -20,6 +20,7 @@ from logging import getLogger from logging.handlers import RotatingFileHandler from configparser import ConfigParser +from threading import Lock import redis @@ -59,6 +60,29 @@ def reraise(tp, value, tb=None): JOBS_EXPIRATION = 15 # 20 * 60 # how long to keep job ids in the EOS registry (exactly-once-semantics) config = None +first_exception_info = None +safe_thread_lock = Lock() + + +class SafeThread(threading.Thread): + def __init__(self, *, target, name, args=(), kwargs=None, daemon=None): + super().__init__(None, target, name, args, kwargs, daemon=daemon) + self.exc_info = None + + def run(self): + global first_exception_info + try: + self._target(*self._args, **self._kwargs) + except: + exc_info = sys.exc_info() + logger.info("exception in '%s'", self.name, exc_info=exc_info) + with safe_thread_lock: + if not first_exception_info: + first_exception_info = sys.exc_info() + finally: + # Avoid a refcycle if the thread is running a function with + # an argument that has a member that points to the thread. + del self._target, self._args, self._kwargs class LineTimeout(Exception): @@ -439,9 +463,7 @@ def _kill(): except Exception as e: self.logger.error(e) - thread = threading.Thread(target=_kill, name="killer-%s" % self.job_id) - thread.daemon = True - thread.start() + SafeThread(target=_kill, name="killer-%s" % self.job_id, daemon=True).start() self.reset_timeout(new_timeout=graceful_timeout + 10) @@ -451,7 +473,7 @@ def __init__(self, *args, **kwargs): super(RebootJob, self).__init__(*args, **kwargs) def start(self): - threading.Thread(target=self.reboot_host, name="Reboot").start() + SafeThread(target=self.reboot_host, name="Reboot").start() def reboot_host(self): with open(REBOOT_FILENAME, 'w') as f: @@ -497,8 +519,6 @@ def __init__(self): self.output_lock = threading.RLock() self.redis = None self.host_id = None - self.redis_fetcher = None - self.redis_sender = None self.job_poller = None self.fds_poller = select.poll() self.fds_to_channels = {} @@ -620,7 +640,7 @@ def stop_for_reboot(self, requested_by): requested_by.log("Some jobs not yet finished, setting exit code to 'reboot' and proceeding") with self.pipeline() as pipeline: - for job_id, job in self.current_processes.items(): + for job_id, job in list(self.current_processes.items()): if job_id == requested_by.job_id: continue job.set_result('reboot') @@ -744,35 +764,24 @@ def sync_jobs_progress(self): else: time.sleep(CYCLE_DURATION) - def start_worker(self, worker, name): - - def safe_run(): - try: - return worker() - except: # noqa - self.exc_info = sys.exc_info() - logger.debug("exception in '%s'", name, exc_info=self.exc_info) - - t = threading.Thread(target=safe_run, name=name) - t.daemon = True - t.start() - return t - def start(self): + global first_exception_info + first_exception_info = None + self.finalize_previous_session() if os.path.isfile(JOBS_SEEN): with open(JOBS_SEEN, "r") as f: self.seen_jobs = json.load(f) - self.redis_fetcher = self.start_worker(self.fetch_new_jobs, name="RedisFetcher") - self.redis_sender = self.start_worker(self.sync_jobs_progress, name="JobProgress") + SafeThread(target=self.fetch_new_jobs, name="RedisFetcher", daemon=True).start() + SafeThread(target=self.sync_jobs_progress, name="JobProgress", daemon=True).start() while not self.stop_agent.is_set(): if not self.get_jobs_outputs(): time.sleep(CYCLE_DURATION / 10.0) - if self.exc_info: + if first_exception_info: logger.debug("re-raising exception from worker") - reraise(*self.exc_info) + reraise(*first_exception_info) assert False, "exception should have been raised" def setup(self): @@ -819,7 +828,7 @@ def unregister_fileno(self, fileno): def wait_proc(proc, timeout): - t = threading.Thread(target=proc.wait) + t = SafeThread(target=proc.wait, name='wait_proc') t.start() t.join(timeout) return not t.is_alive() diff --git a/tests/uts/test_agent.py b/tests/uts/test_agent.py index 445dd6c..65f300a 100644 --- a/tests/uts/test_agent.py +++ b/tests/uts/test_agent.py @@ -33,14 +33,24 @@ def raise_file_not_found(*args, **kwargs): raise OSError(2, 'No such file or directory') +class RebootMockException(Exception): + pass + + +def reboot_mock_exception(*args): + raise RebootMockException("This is reboot mock exception") + + JOBS_DIR = '/tmp/talker/jobs' EXCEPTION_FILENAME = '/tmp/talker/last_exception' JOBS_SEEN = os.path.join(JOBS_DIR, 'eos.json') +REBOOT_FILENAME = '/tmp/talker/reboot.id' @patch('talker_agent.talker.JOBS_DIR', JOBS_DIR) @patch('talker_agent.talker.JOBS_SEEN', JOBS_SEEN) @patch('talker_agent.talker.EXCEPTION_FILENAME', EXCEPTION_FILENAME) +@patch('talker_agent.talker.REBOOT_FILENAME', REBOOT_FILENAME) class TestAgent(unittest.TestCase): def setUp(self): @@ -167,3 +177,8 @@ def test_max_output_per_channel(self): res = get_stdout(self.agent.redis, job_id) expected_val = val.replace('\\n', '\n') * val_repeats self.assertEqual(res, expected_val) + + @patch('talker_agent.talker.RebootJob.reboot_host', reboot_mock_exception) + def test_safe_thread(self): + _ = self.run_cmd_on_agent('reboot', force=True) + self.assert_agent_exception(RebootMockException)