From 013171617f3c3402408431bbd260d30c11944635 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jelmer=20Vernoo=C4=B3?= Date: Sun, 28 Jun 2020 16:07:30 +0000 Subject: [PATCH] Add a RefsContainer.watch interface. Fixes #751 --- dulwich/refs.py | 60 +++++++++++++++++++++++++++++++++----- dulwich/tests/test_refs.py | 4 +-- 2 files changed, 53 insertions(+), 11 deletions(-) diff --git a/dulwich/refs.py b/dulwich/refs.py index aded324bd..75e1cba10 100644 --- a/dulwich/refs.py +++ b/dulwich/refs.py @@ -393,8 +393,7 @@ def get_symrefs(self): def watch(self): """Watch for changes to the refs in this container. - Returns a context manager that yields tuples with (refname, old_sha, - new_sha) + Returns a context manager that yields tuples with (refname, new_sha) """ raise NotImplementedError(self.watch) @@ -443,9 +442,9 @@ def read_loose_ref(self, name): def get_packed_refs(self): return {} - def _notify(self, ref, oldsha, newsha): + def _notify(self, ref, newsha): for watcher in self._watchers: - watcher._notify((ref, oldsha, newsha)) + watcher._notify((ref, newsha)) def watch(self): return _DictRefsWatcher(self) @@ -455,7 +454,7 @@ def set_symbolic_ref(self, name, other, committer=None, old = self.follow(name)[-1] new = SYMREF + other self._refs[name] = new - self._notify(name, old, new) + self._notify(name, new) self._log(name, old, new, committer=committer, timestamp=timestamp, timezone=timezone, message=message) @@ -468,7 +467,7 @@ def set_if_equals(self, name, old_ref, new_ref, committer=None, self._check_refname(realname) old = self._refs.get(realname) self._refs[realname] = new_ref - self._notify(realname, old, new_ref) + self._notify(realname, new_ref) self._log(realname, old, new_ref, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True @@ -478,7 +477,7 @@ def add_if_new(self, name, ref, committer=None, timestamp=None, if name in self._refs: return False self._refs[name] = ref - self._notify(name, None, ref) + self._notify(name, ref) self._log(name, None, ref, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True @@ -492,7 +491,7 @@ def remove_if_equals(self, name, old_ref, committer=None, timestamp=None, except KeyError: pass else: - self._notify(name, old, None) + self._notify(name, None) self._log(name, old, None, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True @@ -546,6 +545,47 @@ def get_peeled(self, name): return self._refs[name] +class _InotifyRefsWatcher(object): + + def __init__(self, path): + import pyinotify + from queue import Queue + self.path = os.fsdecode(path) + self.manager = pyinotify.WatchManager() + self.manager.add_watch( + self.path, pyinotify.IN_DELETE | + pyinotify.IN_CLOSE_WRITE | pyinotify.IN_MOVED_TO, rec=True, + auto_add=True) + + self.notifier = pyinotify.ThreadedNotifier( + self.manager, default_proc_fun=self._notify) + self.queue = Queue() + + def _notify(self, event): + if event.dir: + return + if event.pathname.endswith('.lock'): + return + ref = os.fsencode(os.path.relpath(event.pathname, self.path)) + if event.maskname == 'IN_DELETE': + self.queue.put_nowait((ref, None)) + elif event.maskname in ('IN_CLOSE_WRITE', 'IN_MOVED_TO'): + with open(event.pathname, 'rb') as f: + sha = f.readline().rstrip(b'\n\r') + self.queue.put_nowait((ref, sha)) + + def __next__(self): + return self.queue.get() + + def __enter__(self): + self.notifier.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.notifier.stop() + return False + + class DiskRefsContainer(RefsContainer): """Refs container that reads refs from disk.""" @@ -892,6 +932,10 @@ def remove_if_equals(self, name, old_ref, committer=None, timestamp=None, return True + def watch(self): + import pyinotify # noqa: F401 + return _InotifyRefsWatcher(self.path) + def _split_ref_line(line): """Split a single ref line into a tuple of SHA1 and name.""" diff --git a/dulwich/tests/test_refs.py b/dulwich/tests/test_refs.py index b2b247956..ac0ec5136 100644 --- a/dulwich/tests/test_refs.py +++ b/dulwich/tests/test_refs.py @@ -333,20 +333,18 @@ def test_watch(self): b'48d01bd4b77fed026b154d16493e5deab78f02ec') change = next(watcher) self.assertEqual( - (b'refs/remotes/origin/other', None, + (b'refs/remotes/origin/other', b'48d01bd4b77fed026b154d16493e5deab78f02ec'), change) self._refs[b'refs/remotes/origin/other'] = ( b'48d01bd4b77fed026b154d16493e5deab78f02ed') change = next(watcher) self.assertEqual( (b'refs/remotes/origin/other', - b'48d01bd4b77fed026b154d16493e5deab78f02ec', b'48d01bd4b77fed026b154d16493e5deab78f02ed'), change) del self._refs[b'refs/remotes/origin/other'] change = next(watcher) self.assertEqual( (b'refs/remotes/origin/other', - b'48d01bd4b77fed026b154d16493e5deab78f02ed', None), change)