Skip to content

Commit

Permalink
Fix use after free errors (#177)
Browse files Browse the repository at this point in the history
* Fix use after free errors

Fixes a few cases where use after free could happen, and one where it can be
easily reproduced.

When using the FSEvents back-end, both the `Watcher` and a `State` pointers are
passed as part of the context parameter to a C style callback.

Both of these pointers may be free-ed before the callback is called before this
commit.

After this commit, these two pointers, as well as all most others in the
codebase are replaced with `shared_ptr`.

This is a lazy fix and goes against C++ core guidelines - "F.7".

A secondary pass would use reference counting only in the shared ownership
cases. For most back-ends this shared ownership happens on: error types,
subscription maps/sets and state structs.
  • Loading branch information
yamadapc committed Jun 13, 2024
1 parent 8dfd994 commit e9af634
Show file tree
Hide file tree
Showing 23 changed files with 210 additions and 196 deletions.
12 changes: 6 additions & 6 deletions src/Backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,23 @@ Backend::~Backend() {
#endif
}

void Backend::watch(Watcher &watcher) {
void Backend::watch(WatcherRef watcher) {
std::unique_lock<std::mutex> lock(mMutex);
auto res = mSubscriptions.find(&watcher);
auto res = mSubscriptions.find(watcher);
if (res == mSubscriptions.end()) {
try {
this->subscribe(watcher);
mSubscriptions.insert(&watcher);
mSubscriptions.insert(watcher);
} catch (std::exception &err) {
unref();
throw;
}
}
}

void Backend::unwatch(Watcher &watcher) {
void Backend::unwatch(WatcherRef watcher) {
std::unique_lock<std::mutex> lock(mMutex);
size_t deleted = mSubscriptions.erase(&watcher);
size_t deleted = mSubscriptions.erase(watcher);
if (deleted > 0) {
this->unsubscribe(watcher);
unref();
Expand All @@ -168,7 +168,7 @@ void Backend::unref() {
}

void Backend::handleWatcherError(WatcherError &err) {
unwatch(*err.mWatcher);
unwatch(err.mWatcher);
err.mWatcher->notifyError(err);
}

Expand Down
14 changes: 7 additions & 7 deletions src/Backend.hh
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@ public:
void notifyStarted();

virtual void start();
virtual void writeSnapshot(Watcher &watcher, std::string *snapshotPath) = 0;
virtual void getEventsSince(Watcher &watcher, std::string *snapshotPath) = 0;
virtual void subscribe(Watcher &watcher) = 0;
virtual void unsubscribe(Watcher &watcher) = 0;
virtual void writeSnapshot(WatcherRef watcher, std::string *snapshotPath) = 0;
virtual void getEventsSince(WatcherRef watcher, std::string *snapshotPath) = 0;
virtual void subscribe(WatcherRef watcher) = 0;
virtual void unsubscribe(WatcherRef watcher) = 0;

static std::shared_ptr<Backend> getShared(std::string backend);

void watch(Watcher &watcher);
void unwatch(Watcher &watcher);
void watch(WatcherRef watcher);
void unwatch(WatcherRef watcher);
void unref();
void handleWatcherError(WatcherError &err);

std::mutex mMutex;
std::thread mThread;
private:
std::unordered_set<Watcher *> mSubscriptions;
std::unordered_set<WatcherRef> mSubscriptions;
Signal mStartedSignal;

void handleError(std::exception &err);
Expand Down
10 changes: 5 additions & 5 deletions src/Watcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
using namespace Napi;

struct WatcherHash {
std::size_t operator() (std::shared_ptr<Watcher> const &k) const {
std::size_t operator() (WatcherRef const &k) const {
return std::hash<std::string>()(k->mDir);
}
};

struct WatcherCompare {
size_t operator() (std::shared_ptr<Watcher> const &a, std::shared_ptr<Watcher> const &b) const {
size_t operator() (WatcherRef const &a, WatcherRef const &b) const {
return *a == *b;
}
};

static std::unordered_set<std::shared_ptr<Watcher>, WatcherHash, WatcherCompare> sharedWatchers;
static std::unordered_set<WatcherRef , WatcherHash, WatcherCompare> sharedWatchers;

std::shared_ptr<Watcher> Watcher::getShared(std::string dir, std::unordered_set<std::string> ignorePaths, std::unordered_set<Glob> ignoreGlobs) {
std::shared_ptr<Watcher> watcher = std::make_shared<Watcher>(dir, ignorePaths, ignoreGlobs);
WatcherRef Watcher::getShared(std::string dir, std::unordered_set<std::string> ignorePaths, std::unordered_set<Glob> ignoreGlobs) {
WatcherRef watcher = std::make_shared<Watcher>(dir, ignorePaths, ignoreGlobs);
auto found = sharedWatchers.find(watcher);
if (found != sharedWatchers.end()) {
return *found;
Expand Down
18 changes: 13 additions & 5 deletions src/Watcher.hh
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,26 @@

using namespace Napi;

struct Watcher;
using WatcherRef = std::shared_ptr<Watcher>;

struct Callback {
Napi::ThreadSafeFunction tsfn;
Napi::FunctionReference ref;
std::thread::id threadId;
};

class WatcherState {
public:
virtual ~WatcherState() = default;
};

struct Watcher {
std::string mDir;
std::unordered_set<std::string> mIgnorePaths;
std::unordered_set<Glob> mIgnoreGlobs;
EventList mEvents;
void *state;
std::shared_ptr<WatcherState> state;

Watcher(std::string dir, std::unordered_set<std::string> ignorePaths, std::unordered_set<Glob> ignoreGlobs);
~Watcher();
Expand All @@ -42,7 +50,7 @@ struct Watcher {
bool isIgnored(std::string path);
void destroy();

static std::shared_ptr<Watcher> getShared(std::string dir, std::unordered_set<std::string> ignorePaths, std::unordered_set<Glob> ignoreGlobs);
static WatcherRef getShared(std::string dir, std::unordered_set<std::string> ignorePaths, std::unordered_set<Glob> ignoreGlobs);

private:
std::mutex mMutex;
Expand All @@ -57,9 +65,9 @@ private:

class WatcherError : public std::runtime_error {
public:
Watcher *mWatcher;
WatcherError(std::string msg, Watcher *watcher) : std::runtime_error(msg), mWatcher(watcher) {}
WatcherError(const char *msg, Watcher *watcher) : std::runtime_error(msg), mWatcher(watcher) {}
WatcherRef mWatcher;
WatcherError(std::string msg, WatcherRef watcher) : std::runtime_error(msg), mWatcher(watcher) {}
WatcherError(const char *msg, WatcherRef watcher) : std::runtime_error(msg), mWatcher(watcher) {}
};

#endif
16 changes: 8 additions & 8 deletions src/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ class WriteSnapshotRunner : public PromiseRunner {
}
private:
std::shared_ptr<Backend> backend;
std::shared_ptr<Watcher> watcher;
WatcherRef watcher;
std::string snapshotPath;

void execute() override {
backend->writeSnapshot(*watcher, &snapshotPath);
backend->writeSnapshot(watcher, &snapshotPath);
}
};

Expand All @@ -107,11 +107,11 @@ class GetEventsSinceRunner : public PromiseRunner {
}
private:
std::shared_ptr<Backend> backend;
std::shared_ptr<Watcher> watcher;
WatcherRef watcher;
std::string snapshotPath;

void execute() override {
backend->getEventsSince(*watcher, &snapshotPath);
backend->getEventsSince(watcher, &snapshotPath);
}

Value getResult() override {
Expand Down Expand Up @@ -169,13 +169,13 @@ class SubscribeRunner : public PromiseRunner {
}

private:
std::shared_ptr<Watcher> watcher;
WatcherRef watcher;
std::shared_ptr<Backend> backend;
FunctionReference callback;

void execute() override {
try {
backend->watch(*watcher);
backend->watch(watcher);
} catch (std::exception &err) {
watcher->destroy();
throw;
Expand All @@ -197,13 +197,13 @@ class UnsubscribeRunner : public PromiseRunner {
}

private:
std::shared_ptr<Watcher> watcher;
WatcherRef watcher;
std::shared_ptr<Backend> backend;
bool shouldUnwatch;

void execute() override {
if (shouldUnwatch) {
backend->unwatch(*watcher);
backend->unwatch(watcher);
}
}
};
Expand Down
20 changes: 10 additions & 10 deletions src/kqueue/KqueueBackend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void KqueueBackend::start() {
}

// Track all of the watchers that are touched so we can notify them at the end of the events.
std::unordered_set<Watcher *> watchers;
std::unordered_set<WatcherRef> watchers;

for (int i = 0; i < event_count; i++) {
int flags = events[i].fflags;
Expand Down Expand Up @@ -118,20 +118,20 @@ KqueueBackend::~KqueueBackend() {
mEndedSignal.wait();
}

void KqueueBackend::subscribe(Watcher &watcher) {
void KqueueBackend::subscribe(WatcherRef watcher) {
// Build a full directory tree recursively, and watch each directory.
std::shared_ptr<DirTree> tree = getTree(watcher);

for (auto it = tree->entries.begin(); it != tree->entries.end(); it++) {
bool success = watchDir(watcher, it->second.path, tree);
if (!success) {
throw WatcherError(std::string("error watching " + watcher.mDir + ": " + strerror(errno)), &watcher);
throw WatcherError(std::string("error watching " + watcher->mDir + ": " + strerror(errno)), watcher);
}
}
}

bool KqueueBackend::watchDir(Watcher &watcher, std::string path, std::shared_ptr<DirTree> tree) {
if (watcher.isIgnored(path)) {
bool KqueueBackend::watchDir(WatcherRef watcher, std::string path, std::shared_ptr<DirTree> tree) {
if (watcher->isIgnored(path)) {
return false;
}

Expand All @@ -141,7 +141,7 @@ bool KqueueBackend::watchDir(Watcher &watcher, std::string path, std::shared_ptr
}

KqueueSubscription sub = {
.watcher = &watcher,
.watcher = watcher,
.path = path,
.tree = tree
};
Expand Down Expand Up @@ -189,7 +189,7 @@ std::vector<KqueueSubscription *> KqueueBackend::findSubscriptions(std::string &
return subs;
}

bool KqueueBackend::compareDir(int fd, std::string &path, std::unordered_set<Watcher *> &watchers) {
bool KqueueBackend::compareDir(int fd, std::string &path, std::unordered_set<WatcherRef> &watchers) {
// macOS doesn't support fdclosedir, so we have to duplicate the file descriptor
// to ensure the closedir doesn't also stop watching.
#if __APPLE__
Expand Down Expand Up @@ -240,7 +240,7 @@ bool KqueueBackend::compareDir(int fd, std::string &path, std::unordered_set<Wat
sub->watcher->mEvents.create(fullpath);
watchers.emplace(sub->watcher);

bool success = watchDir(*sub->watcher, fullpath, sub->tree);
bool success = watchDir(sub->watcher, fullpath, sub->tree);
if (!success) {
sub->tree->remove(fullpath);
return false;
Expand Down Expand Up @@ -289,10 +289,10 @@ bool KqueueBackend::compareDir(int fd, std::string &path, std::unordered_set<Wat
return true;
}

void KqueueBackend::unsubscribe(Watcher &watcher) {
void KqueueBackend::unsubscribe(WatcherRef watcher) {
// Find any subscriptions pointing to this watcher, and remove them.
for (auto it = mSubscriptions.begin(); it != mSubscriptions.end();) {
if (it->second.watcher == &watcher) {
if (it->second.watcher.get() == watcher.get()) {
if (mSubscriptions.count(it->first) == 1) {
// Closing the file descriptor automatically unwatches it in the kqueue.
close(it->second.fd);
Expand Down
10 changes: 5 additions & 5 deletions src/kqueue/KqueueBackend.hh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "../Signal.hh"

struct KqueueSubscription {
Watcher *watcher;
WatcherRef watcher;
std::string path;
std::shared_ptr<DirTree> tree;
int fd;
Expand All @@ -18,17 +18,17 @@ class KqueueBackend : public BruteForceBackend {
public:
void start() override;
~KqueueBackend();
void subscribe(Watcher &watcher) override;
void unsubscribe(Watcher &watcher) override;
void subscribe(WatcherRef watcher) override;
void unsubscribe(WatcherRef watcher) override;
private:
int mKqueue;
int mPipe[2];
std::unordered_multimap<std::string, KqueueSubscription> mSubscriptions;
std::unordered_map<int, DirEntry *> mFdToEntry;
Signal mEndedSignal;

bool watchDir(Watcher &watcher, std::string path, std::shared_ptr<DirTree> tree);
bool compareDir(int fd, std::string &dir, std::unordered_set<Watcher *> &watchers);
bool watchDir(WatcherRef watcher, std::string path, std::shared_ptr<DirTree> tree);
bool compareDir(int fd, std::string &dir, std::unordered_set<WatcherRef> &watchers);
std::vector<KqueueSubscription *> findSubscriptions(std::string &path);
};

Expand Down
22 changes: 11 additions & 11 deletions src/linux/InotifyBackend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,21 @@ InotifyBackend::~InotifyBackend() {
}

// This function is called by Backend::watch which takes a lock on mMutex
void InotifyBackend::subscribe(Watcher &watcher) {
void InotifyBackend::subscribe(WatcherRef watcher) {
// Build a full directory tree recursively, and watch each directory.
std::shared_ptr<DirTree> tree = getTree(watcher);

for (auto it = tree->entries.begin(); it != tree->entries.end(); it++) {
if (it->second.isDir) {
bool success = watchDir(watcher, it->second.path, tree);
if (!success) {
throw WatcherError(std::string("inotify_add_watch on '") + it->second.path + std::string("' failed: ") + strerror(errno), &watcher);
throw WatcherError(std::string("inotify_add_watch on '") + it->second.path + std::string("' failed: ") + strerror(errno), watcher);
}
}
}
}

bool InotifyBackend::watchDir(Watcher &watcher, std::string path, std::shared_ptr<DirTree> tree) {
bool InotifyBackend::watchDir(WatcherRef watcher, std::string path, std::shared_ptr<DirTree> tree) {
int wd = inotify_add_watch(mInotify, path.c_str(), INOTIFY_MASK);
if (wd == -1) {
return false;
Expand All @@ -87,7 +87,7 @@ bool InotifyBackend::watchDir(Watcher &watcher, std::string path, std::shared_pt
std::shared_ptr<InotifySubscription> sub = std::make_shared<InotifySubscription>();
sub->tree = tree;
sub->path = path;
sub->watcher = &watcher;
sub->watcher = watcher;
mSubscriptions.emplace(wd, sub);

return true;
Expand All @@ -98,7 +98,7 @@ void InotifyBackend::handleEvents() {
struct inotify_event *event;

// Track all of the watchers that are touched so we can notify them at the end of the events.
std::unordered_set<Watcher *> watchers;
std::unordered_set<WatcherRef> watchers;

while (true) {
int n = read(mInotify, &buf, BUFFER_SIZE);
Expand Down Expand Up @@ -131,7 +131,7 @@ void InotifyBackend::handleEvents() {
}
}

void InotifyBackend::handleEvent(struct inotify_event *event, std::unordered_set<Watcher *> &watchers) {
void InotifyBackend::handleEvent(struct inotify_event *event, std::unordered_set<WatcherRef> &watchers) {
std::unique_lock<std::mutex> lock(mMutex);

// Find the subscriptions for this watch descriptor
Expand All @@ -150,7 +150,7 @@ void InotifyBackend::handleEvent(struct inotify_event *event, std::unordered_set

bool InotifyBackend::handleSubscription(struct inotify_event *event, std::shared_ptr<InotifySubscription> sub) {
// Build full path and check if its in our ignore list.
Watcher *watcher = sub->watcher;
std::shared_ptr<Watcher> watcher = sub->watcher;
std::string path = std::string(sub->path);
bool isDir = event->mask & IN_ISDIR;

Expand All @@ -174,7 +174,7 @@ bool InotifyBackend::handleSubscription(struct inotify_event *event, std::shared
DirEntry *entry = sub->tree->add(path, CONVERT_TIME(st.st_mtim), S_ISDIR(st.st_mode));

if (entry->isDir) {
bool success = watchDir(*watcher, path, sub->tree);
bool success = watchDir(watcher, path, sub->tree);
if (!success) {
sub->tree->remove(path);
return false;
Expand Down Expand Up @@ -213,14 +213,14 @@ bool InotifyBackend::handleSubscription(struct inotify_event *event, std::shared
}

// This function is called by Backend::unwatch which takes a lock on mMutex
void InotifyBackend::unsubscribe(Watcher &watcher) {
void InotifyBackend::unsubscribe(WatcherRef watcher) {
// Find any subscriptions pointing to this watcher, and remove them.
for (auto it = mSubscriptions.begin(); it != mSubscriptions.end();) {
if (it->second->watcher == &watcher) {
if (it->second->watcher.get() == watcher.get()) {
if (mSubscriptions.count(it->first) == 1) {
int err = inotify_rm_watch(mInotify, it->first);
if (err == -1) {
throw WatcherError(std::string("Unable to remove watcher: ") + strerror(errno), &watcher);
throw WatcherError(std::string("Unable to remove watcher: ") + strerror(errno), watcher);
}
}

Expand Down
Loading

0 comments on commit e9af634

Please sign in to comment.