Skip to content

Commit

Permalink
merge #156
Browse files Browse the repository at this point in the history
feat: Implement reconnection mechanism for Discord
  • Loading branch information
vyfor authored Dec 31, 2024
2 parents 86f2a70 + b0de9ba commit 29aefd3
Show file tree
Hide file tree
Showing 19 changed files with 306 additions and 100 deletions.
7 changes: 7 additions & 0 deletions lua/cord/plugin/config/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
---@class CordAdvancedConfig
---@field plugin? CordAdvancedPluginConfig configuration
---@field server? CordAdvancedServerConfig configuration
---@field discord? CordAdvancedDiscordConfig configuration
---@field cursor_update_mode? string Cursor update mode
---@field variables_in_functions? boolean Whether to use variables in functions
---@field match_in_mappings? boolean Whether to use `:match()` calls in mappings to detect languages, not supported by Neovim, by matching against file extensions
Expand All @@ -72,6 +73,9 @@
---@field executable_path? string Path to the server's executable
---@field timeout? integer Timeout in milliseconds

---@class CordAdvancedDiscordConfig
---@field reconnect_interval? integer Reconnect interval in milliseconds, 0 to disable

---@alias CordVariablesConfig { [string]: string|fun(opts: CordOpts):string }

---@class CordConfig
Expand Down Expand Up @@ -153,6 +157,9 @@ M.opts = {
executable_path = nil,
timeout = 60000,
},
discord = {
reconnect_interval = 0,
},
cursor_update_mode = 'on_hold',
variables_in_functions = false,
match_in_mappings = true,
Expand Down
13 changes: 9 additions & 4 deletions lua/cord/server/event/receiver.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,25 @@ end
function Handler:on_event(type, data)
local handler = self.handlers[type]
if handler then
handler(data)
if handler.oneshot then self.handlers[type] = nil end
handler.callback(data)
else
self.queue[type] = data
end
end

function Handler:register(type, callback)
function Handler:register(type, oneshot, callback)
local data = self.queue[type]
if data then
callback(data)
self.queue[type] = nil
if oneshot then return end
end

self.handlers[type] = callback
self.handlers[type] = {
oneshot = oneshot,
callback = callback,
}
end

function Handler:run()
Expand Down Expand Up @@ -70,7 +75,7 @@ function Handler:run()
end

function Handler:setup_default_handlers()
self:register('log', function(data)
self:register('log', false, function(data)
if data.level and data.message then
logger.log_raw(data.level, data.message)
if data.level == vim.log.levels.ERROR then
Expand Down
25 changes: 24 additions & 1 deletion lua/cord/server/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ function M:run()

M.rx:register(
'ready',
true,
vim.schedule_wrap(function()
self.status = 'ready'
async.run(function()
Expand All @@ -86,12 +87,34 @@ function M:run()

manager:run()
M.manager = manager

M.rx:register(
'disconnect',
false,
vim.schedule_wrap(function()
self.status = 'connected'
M.manager:cleanup()
if self.config.hooks.on_disconnect then
self.config.hooks.on_disconnect()
end

M.rx:register(
'ready',
true,
vim.schedule_wrap(function()
self.status = 'ready'
logger.info 'Reconnected to Discord'
M.manager:run()
end)
)
end)
)
end)
end)
)

M.rx:run()
logger.debug 'Server initialized'
M.rx:run()
end)()
end

Expand Down
2 changes: 2 additions & 0 deletions lua/cord/server/spawn/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ M.spawn = async.wrap(function(client_id, pipe_path, exec_path)
client_id,
'-t',
require('cord.plugin.config').opts.advanced.server.timeout,
'-r',
require('cord.plugin.config').opts.advanced.discord.reconnect_interval,
},
on_stdout = function(data)
if data:match 'Ready' then resolve(false) end
Expand Down
12 changes: 12 additions & 0 deletions lua/cord/server/update/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,16 @@ M.fetch = async.wrap(function()
end)
end)

M.version = function()
local executable_path = require('cord.server.fs').get_executable_path()
local process = require 'cord.core.uv.process'

local result, err = process
.spawn({
cmd = executable_path,
args = { '-v' },
})
:get()
end

return M
24 changes: 24 additions & 0 deletions src/cli/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::error::CliError;
use crate::echo;

const DEFAULT_TIMEOUT: u64 = 60000;
const DEFAULT_RECONNECT_INTERVAL: u64 = 0;
#[cfg(target_os = "windows")]
const DEFAULT_PIPE_NAME: &str = "\\\\.\\pipe\\cord-ipc";
#[cfg(not(target_os = "windows"))]
Expand All @@ -14,6 +15,7 @@ pub struct Args {
pub pipe_name: String,
pub client_id: u64,
pub timeout: u64,
pub reconnect_interval: u64,
}

impl Args {
Expand All @@ -23,6 +25,7 @@ impl Args {
let mut pipe_name = None;
let mut client_id = None;
let mut timeout = None;
let mut reconnect_interval = None;

let mut i = 1;
while i < args.len() {
Expand Down Expand Up @@ -73,6 +76,25 @@ impl Args {
return Err(CliError::Missing("--timeout").into());
}
}
"--reconnect-interval" | "-r" => {
if i + 1 < args.len() {
match args[i + 1].parse() {
Ok(t) => reconnect_interval = Some(t),
_ => {
return Err(CliError::Invalid(
"--reconnect-interval",
"reconnect interval must be a valid u64",
)
.into())
}
}
i += 2;
} else {
return Err(
CliError::Missing("--reconnect-interval").into()
);
}
}
other => {
return Err(CliError::Unknown(other.to_string()).into());
}
Expand All @@ -84,6 +106,8 @@ impl Args {
.unwrap_or_else(|| DEFAULT_PIPE_NAME.to_string()),
client_id: client_id.ok_or(CliError::Missing("--client-id"))?,
timeout: timeout.unwrap_or(DEFAULT_TIMEOUT),
reconnect_interval: reconnect_interval
.unwrap_or(DEFAULT_RECONNECT_INTERVAL),
})
}
}
31 changes: 20 additions & 11 deletions src/cord.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use std::time::Duration;

use crate::error::CordErrorKind;
Expand Down Expand Up @@ -27,7 +27,7 @@ use crate::util::logger::{LogLevel, Logger};
pub struct Cord {
pub config: Config,
pub session_manager: Arc<SessionManager>,
pub rich_client: Arc<RichClient>,
pub rich_client: Arc<RwLock<RichClient>>,
pub pipe: PipeServer,
pub tx: Sender<Message>,
pub rx: Receiver<Message>,
Expand All @@ -43,7 +43,7 @@ impl Cord {
let (tx, rx) = mpsc::channel::<Message>();
let session_manager = Arc::new(SessionManager::default());
let rich_client = match RichClient::connect(config.client_id) {
Ok(client) => Arc::new(client),
Ok(client) => Arc::new(RwLock::new(client)),
Err(_) => {
return Err(crate::error::CordError::new(
CordErrorKind::Io,
Expand Down Expand Up @@ -114,20 +114,22 @@ impl Cord {

/// Starts RPC with Discord.
pub fn start_rpc(&mut self) -> crate::Result<()> {
self.rich_client.handshake()?;
let tx = self.tx.clone();

Arc::get_mut(&mut self.rich_client)
.expect("Failed to start read thread")
.start_read_thread(tx.clone())?;
let mut rich_client = self
.rich_client
.write()
.expect("Failed to lock rich client");
rich_client.handshake()?;
rich_client.start_read_thread(self.tx.clone())?;

Ok(())
}

/// Cleans up before shutdown.
pub fn cleanup(&mut self) {
if let Some(client) = Arc::get_mut(&mut self.rich_client) {
client.close();
if let Ok(client) = client.get_mut() {
client.close();
}
}

self.pipe.stop();
Expand All @@ -146,15 +148,22 @@ pub struct Config {
pub pipe_name: String,
pub client_id: u64,
pub timeout: u64,
pub reconnect_interval: u64,
}

impl Config {
/// Creates a new configuration.
pub fn new(pipe_name: String, client_id: u64, timeout: u64) -> Self {
pub fn new(
pipe_name: String,
client_id: u64,
timeout: u64,
reconnect_interval: u64,
) -> Self {
Self {
pipe_name,
client_id,
timeout,
reconnect_interval,
}
}
}
41 changes: 39 additions & 2 deletions src/ipc/discord/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::sync::atomic::AtomicBool;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::Sender;
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::Duration;

use crate::messages::message::Message;
use crate::presence::packet::Packet;
Expand All @@ -16,14 +18,15 @@ use crate::protocol::json::Json;
pub struct RichClient {
pub client_id: u64,
#[cfg(target_os = "windows")]
pub pipe: Option<std::sync::Arc<std::fs::File>>,
pub pipe: Option<Arc<std::fs::File>>,
#[cfg(not(target_os = "windows"))]
pub read_pipe: Option<std::os::unix::net::UnixStream>,
#[cfg(not(target_os = "windows"))]
pub write_pipe: Option<std::os::unix::net::UnixStream>,
pub pid: u32,
pub is_ready: AtomicBool,
pub thread_handle: Option<JoinHandle<()>>,
pub is_reconnecting: Arc<AtomicBool>,
}

/// Defines methods for connecting and closing the client.
Expand Down Expand Up @@ -69,4 +72,38 @@ impl RichClient {
_ => Ok(()),
}
}

/// Reconnects to Discord with exponential backoff.
pub fn reconnect(
&mut self,
initial_interval: u64,
tx: Sender<Message>,
) -> crate::Result<()> {
self.is_reconnecting.store(true, Ordering::SeqCst);
self.close();

std::thread::sleep(Duration::from_millis(500));

while self.is_reconnecting.load(Ordering::SeqCst) {
if let Ok(mut client) = Self::connect(self.client_id) {
if client.handshake().is_ok() {
*self = client;
if let Err(e) = self.start_read_thread(tx) {
self.is_reconnecting.store(false, Ordering::SeqCst);
return Err(e);
};

break;
} else {
client.close();
}
};

std::thread::sleep(Duration::from_millis(initial_interval));
}

self.is_reconnecting.store(false, Ordering::SeqCst);

Ok(())
}
}
35 changes: 17 additions & 18 deletions src/ipc/discord/platform/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::io::{self, Read, Write};
use std::net::Shutdown;
use std::os::unix::net::UnixStream;
use std::sync::mpsc::Sender;
use std::sync::Arc;

use crate::ipc::discord::client::{Connection, RichClient};
use crate::ipc::discord::error::DiscordError;
Expand Down Expand Up @@ -52,6 +53,7 @@ impl Connection for RichClient {
pid: std::process::id(),
is_ready: false.into(),
thread_handle: None,
is_reconnecting: Arc::new(false.into()),
});
}
Err(e) => match e.kind() {
Expand Down Expand Up @@ -169,24 +171,21 @@ impl Connection for RichClient {
}

fn write(&self, opcode: u32, data: Option<&[u8]>) -> crate::Result<()> {
self.write_pipe.as_ref().map_or(
Err(DiscordError::PipeNotFound.into()),
|mut pipe| {
let payload = match data {
Some(packet) => {
let mut payload =
utils::encode(opcode, packet.len() as u32);
payload.extend_from_slice(packet);
payload
}
None => utils::encode(opcode, 0),
};

match pipe.write_all(&payload) {
Ok(_) => Ok(()),
Err(e) => Err(DiscordError::Io(e).into()),
self.write_pipe.as_ref().map_or(Ok(()), |mut pipe| {
let payload = match data {
Some(packet) => {
let mut payload =
utils::encode(opcode, packet.len() as u32);
payload.extend_from_slice(packet);
payload
}
},
)
None => utils::encode(opcode, 0),
};

match pipe.write_all(&payload) {
Ok(_) => Ok(()),
Err(e) => Err(DiscordError::Io(e).into()),
}
})
}
}
Loading

0 comments on commit 29aefd3

Please sign in to comment.