Skip to content

Commit

Permalink
Add automatic reconnect, passing test
Browse files Browse the repository at this point in the history
  • Loading branch information
joshkunz committed May 5, 2023
1 parent 5ed0f78 commit f61dfcd
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 90 deletions.
16 changes: 9 additions & 7 deletions src/ashuffle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ struct MPDHost {
}
};

} // namespace

std::optional<std::unique_ptr<Loader>> Reloader(mpd::MPD *mpd,
const Options &options) {
// Nothing we can do when `--file` is provided. The user is just stuck
Expand All @@ -190,8 +192,6 @@ std::optional<std::unique_ptr<Loader>> Reloader(mpd::MPD *mpd,
return std::make_unique<MPDLoader>(mpd, options.ruleset, options.group_by);
}

} // namespace

/* Keep adding songs when the queue runs out */
absl::Status Loop(mpd::MPD *mpd, ShuffleChain *songs, const Options &options,
TestDelegate test_d) {
Expand Down Expand Up @@ -273,7 +273,7 @@ absl::Status Loop(mpd::MPD *mpd, ShuffleChain *songs, const Options &options,

absl::StatusOr<std::unique_ptr<mpd::MPD>> Connect(
const mpd::Dialer &d, const Options &options,
std::function<std::string()> &getpass_f) {
std::function<std::string()> *getpass_f) {
/* Attempt to get host from command line if available. Otherwise use
* MPD_HOST variable if available. Otherwise use 'localhost'. */
const char *env_host =
Expand Down Expand Up @@ -324,11 +324,13 @@ absl::StatusOr<std::unique_ptr<mpd::MPD>> Connect(
auth.status().ToString());
return auth.status();
}
if (!mpd_host.password && !auth->authorized) {
if (!mpd_host.password && !auth->authorized && getpass_f != nullptr) {
// If the user did *not* supply a password, and we are missing a
// required command, then try to prompt the user to provide a password.
// Once we get/apply a password, try the required commands again...
PromptPassword(mpd.get(), getpass_f);
// required command, and we're in an interactive mode (where we have
// the ability to prompt for a password) then try to prompt the user to
// provide a password. Once we get/apply a password, try the required
// commands again...
PromptPassword(mpd.get(), *getpass_f);
auth = mpd->CheckCommands(required);
if (!auth.ok()) {
Log().Error("Failed to check required commands: %s",
Expand Down
22 changes: 18 additions & 4 deletions src/ashuffle.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,27 @@
#include <mpd/client.h>

#include "args.h"
#include "load.h"
#include "mpd.h"
#include "rule.h"
#include "shuffle.h"

namespace ashuffle {

namespace {

// A getpass_f value that can be used in non-interactive mode.
std::function<std::string()>* kNonInteractiveGetpass = nullptr;

} // namespace

// `MPD_PORT` environment variables. If a password is needed, no password can
// be found in MPD_HOST, then `getpass_f' will be used to prompt the user
// for a password. If `getpass_f' is NULL, the a default password prompt
// (based on getpass) will be used.
// for a password. If `getpass_f' is NULL, then a password will not be
// prompted.
absl::StatusOr<std::unique_ptr<mpd::MPD>> Connect(
const mpd::Dialer& d, const Options& options,
std::function<std::string()>& getpass_f);
std::function<std::string()>* getpass_f);

struct TestDelegate {
bool (*until_f)() = nullptr;
Expand All @@ -38,7 +46,13 @@ struct TestDelegate {
absl::Status Loop(mpd::MPD* mpd, ShuffleChain* songs, const Options& options,
TestDelegate d = TestDelegate());

// Print the size of the database to the given stream, accounting for grouping.
// Return a loader capable of re-loading the current shuffle chain given
// a particular set of options. If it's not possible to create such a
// loader, returns an empty option.
std::optional<std::unique_ptr<Loader>> Reloader(mpd::MPD* mpd,
const Options& options);
// Print the size of the database to the given stream, accounting for
// grouping.
void PrintChainLength(std::ostream& stream, const ShuffleChain& chain);

} // namespace ashuffle
Expand Down
13 changes: 13 additions & 0 deletions src/log.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,19 @@
#include <iostream>

namespace ashuffle {

std::ostream& operator<<(std::ostream& os, const Log::Level& level) {
switch (level) {
case Log::Level::kInfo:
os << "INFO";
break;
case Log::Level::kError:
os << "ERROR";
break;
}
return os;
}

namespace log {

std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) {
Expand Down
26 changes: 16 additions & 10 deletions src/log.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,37 @@ class Log final {
WriteLog(Level::kInfo, fmt, args...);
}

void InfoStr(std::string_view message) {
WriteLogStr(Level::kInfo, message);
}

template <typename... Args>
void Error(const absl::FormatSpec<Args...>& fmt, Args... args) {
WriteLog(Level::kError, fmt, args...);
}

void ErrorStr(std::string_view message) {
WriteLogStr(Level::kError, message);
}

private:
enum class Level {
kInfo,
kError,
};

friend std::ostream& operator<<(std::ostream&, const Level&);

void WriteLogStr(Level level, std::string_view message) {
log::DefaultLogger().Stream()
<< level << " " << loc_ << ": " << message << std::endl;
}

template <typename... Args>
void WriteLog(Level level, const absl::FormatSpec<Args...>& fmt,
Args... args) {
const char* level_str = nullptr;
switch (level) {
case Level::kInfo:
level_str = "INFO";
break;
case Level::kError:
level_str = "ERROR";
break;
}
log::DefaultLogger().Stream()
<< level_str << " " << loc_ << ": " << absl::StrFormat(fmt, args...)
<< level << " " << loc_ << ": " << absl::StrFormat(fmt, args...)
<< std::endl;
}

Expand Down
81 changes: 73 additions & 8 deletions src/main.cc
Original file line number Diff line number Diff line change
@@ -1,26 +1,39 @@
#include <stdlib.h>
#include <time.h>
#include <cassert>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

#include <absl/time/clock.h>
#include <mpd/connection.h>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/time/time.h"
#include "args.h"
#include "ashuffle.h"
#include "getpass.h"
#include "load.h"
#include "log.h"
#include "mpd_client.h"
#include "shuffle.h"
#include "version.h"

using namespace ashuffle;

namespace {
// This is the maximum amount of time that ashuffle is allowed to be
// disconnected from MPD before it exits. This catches cases where the
// environment changes and it would be impossible for ashuffle to reconnect.
const absl::Duration kMaxDisconnectedTime = absl::Seconds(10);

// The amount of time to wait between reconnection attempts.
const absl::Duration kReconnectWait = absl::Milliseconds(250);

namespace {
std::unique_ptr<Loader> BuildLoader(mpd::MPD* mpd, const Options& opts) {
if (opts.file_in != nullptr && opts.check_uris) {
return std::make_unique<FileMPDLoader>(mpd, opts.ruleset, opts.group_by,
Expand All @@ -32,6 +45,19 @@ std::unique_ptr<Loader> BuildLoader(mpd::MPD* mpd, const Options& opts) {
return std::make_unique<MPDLoader>(mpd, opts.ruleset, opts.group_by);
}

void LoopOnce(mpd::MPD* mpd, ShuffleChain& songs, const Options& options) {
absl::Time start = absl::Now();
absl::Status status = Loop(mpd, &songs, options);
absl::Duration loop_length = absl::Now() - start;
if (!status.ok()) {
Log().Error("LOOP failed after %s with error: %s",
absl::FormatDuration(loop_length), status.ToString());
} else {
Log().Info("LOOP exited successfully after %s (probably a bug)",
absl::FormatDuration(loop_length));
}
}

} // namespace

int main(int argc, const char* argv[]) {
Expand Down Expand Up @@ -67,12 +93,23 @@ int main(int argc, const char* argv[]) {
exit(EXIT_FAILURE);
}

std::function<std::string()> pass_f = [] {
return GetPass(stdin, stdout, "mpd password: ");
log::SetOutput(std::cerr);

bool disable_reconnect = false;
std::function<std::string()> pass_f = [&disable_reconnect] {
disable_reconnect = true;
std::string pass = GetPass(stdin, stdout, "mpd password: ");
Log().InfoStr(
"Disabling reconnect support since the password was "
"provided interactively. Supply password via MPD_HOST "
"environment variable to enable automatic "
"reconnects");
return pass;
};

/* attempt to connect to MPD */
absl::StatusOr<std::unique_ptr<mpd::MPD>> mpd =
Connect(*mpd::client::Dialer(), options, pass_f);
Connect(*mpd::client::Dialer(), options, &pass_f);
if (!mpd.ok()) {
Die("Failed to connect to mpd: %s", mpd.status().ToString());
}
Expand Down Expand Up @@ -128,11 +165,39 @@ int main(int argc, const char* argv[]) {
std::cout << absl::StrFormat(" (%u songs)", number_of_songs);
}
std::cout << "." << std::endl;
} else {
if (auto status = Loop(mpd->get(), &songs, options); !status.ok()) {
Die("Failed to loop: %s", status.ToString());
exit(EXIT_SUCCESS);
return 0;
}

LoopOnce(mpd->get(), songs, options);
if (disable_reconnect) {
exit(EXIT_FAILURE);
}

absl::Time disconnect_begin = absl::Now();
while ((absl::Now() - disconnect_begin) < kMaxDisconnectedTime) {
mpd = Connect(*mpd::client::Dialer(), options, kNonInteractiveGetpass);
if (!mpd.ok()) {
Log().Error("Failed to reconnect to MPD %s, been waiting %s",
mpd.status().ToString(),
absl::FormatDuration(absl::Now() - disconnect_begin));

absl::SleepFor(kReconnectWait);
continue;
}

if (auto l = Reloader(mpd->get(), options); l.has_value()) {
(*l)->Load(&songs);
PrintChainLength(std::cout, songs);
}

LoopOnce(mpd->get(), songs, options);

// Re-set the disconnection timer after we successfully reconnect.
disconnect_begin = absl::Now();
}
Log().Error("Could not reconnect after %s, aborting.",
absl::FormatDuration(kMaxDisconnectedTime));

return 0;
exit(EXIT_FAILURE);
}
Loading

0 comments on commit f61dfcd

Please sign in to comment.