Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix backend messaging #117

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ private CachedConfig() { }

private long sourceCacheTime = new TimeUtil.Time(6L, TimeUnit.HOURS).getMillis();

private boolean messagingRedundancy = true;

public boolean getMessagingRedundancy() { return messagingRedundancy; }

public long getSourceCacheTime() { return sourceCacheTime; }

private long mcleaksCacheTime = new TimeUtil.Time(1L, TimeUnit.DAYS).getMillis();
Expand Down Expand Up @@ -148,6 +152,11 @@ public CachedConfig.Builder messaging(@NotNull List<@NotNull MessagingService> v
return this;
}

public @NotNull CachedConfig.Builder messagingRedundancy(boolean value) {
values.messagingRedundancy = value;
return this;
}

@NotNull
public CachedConfig.Builder sourceCacheTime(@NotNull TimeUtil.Time value) {
if (value.getMillis() <= 0L) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,17 @@ public static <M extends LocalizedCommandSender<M, B>, B> void reloadConfig(

AlgorithmMethod vpnAlgorithmMethod = getVpnAlgorithmMethod(config, debug, console);

boolean messagingRedundancy = config.node("messaging", "settings", "redundancy").getBoolean(true);
if (debug) {
console.sendMessage("<c2>Messaging systems are redundant:</c2> <c1>" + messagingRedundancy + "</c1>");
}

CachedConfig cachedConfig = CachedConfig.builder()
.debug(debug)
.language(language)
.storage(getStorage(config, dataDirectory, debug, console))
.messaging(getMessaging(config, serverId, messagingHandler, new File(dataDirectory, "packets"), debug, console))
.messagingRedundancy(messagingRedundancy)
.sourceCacheTime(getSourceCacheTime(config, debug, console))
.mcleaksCacheTime(getMcLeaksCacheTime(config, debug, console))
.cacheTime(getCacheTime(config, debug, console))
Expand Down Expand Up @@ -459,7 +465,7 @@ private static <M extends LocalizedCommandSender<M, B>, B> MessagingService getM
.getString("/") + "</c1>");
}
try {
return RabbitMQMessagingService.builder(name, serverId, handler, packetDirectory)
return RabbitMQMessagingService.builder(name, serverId, handler, poolSettings.delay, packetDirectory)
.url(url.address, url.port, connectionNode.node("v-host").getString("/"))
.credentials(connectionNode.node("username").getString("guest"), connectionNode.node("password").getString("guest"))
.timeout((int) poolSettings.timeout)
Expand All @@ -475,7 +481,7 @@ private static <M extends LocalizedCommandSender<M, B>, B> MessagingService getM
console.sendMessage("<c2>Creating engine</c2> <c1>" + name + "</c1> <c2>of type redis with address</c2> <c1>" + url.getAddress() + ":" + url.getPort() + "</c1>");
}
try {
return RedisMessagingService.builder(name, serverId, handler, packetDirectory)
return RedisMessagingService.builder(name, serverId, handler, poolSettings.delay, packetDirectory)
.url(url.address, url.port)
.credentials(connectionNode.node("password").getString(""))
.poolSize(poolSettings.minPoolSize, poolSettings.maxPoolSize)
Expand All @@ -492,7 +498,7 @@ private static <M extends LocalizedCommandSender<M, B>, B> MessagingService getM
console.sendMessage("<c2>Creating engine</c2> <c1>" + name + "</c1> <c2>of type NATS with address</c2> <c1>" + url.getAddress() + ":" + url.getPort() + "</c1>");
}
try {
return NATSMessagingService.builder(name, serverId, handler, packetDirectory)
return NATSMessagingService.builder(name, serverId, handler, poolSettings.delay, packetDirectory)
.url(url.address, url.port)
.credentials(connectionNode.node("file").getString(""))
.life((int) poolSettings.timeout)
Expand Down Expand Up @@ -817,7 +823,7 @@ private static <M extends LocalizedCommandSender<M, B>, B> CommentedConfiguratio
}

private static class AddressPort {
private final @NotNull String address;
private final String address;
private final int port;

public <M extends LocalizedCommandSender<M, B>, B> AddressPort(
Expand All @@ -841,8 +847,7 @@ public <M extends LocalizedCommandSender<M, B>, B> AddressPort(
this.port = p;
}

@NotNull
public String getAddress() { return address; }
public @NotNull String getAddress() { return address; }

public int getPort() { return port; }
}
Expand All @@ -852,6 +857,7 @@ private static class PoolSettings {
private final int maxPoolSize;
private final long maxLifetime;
private final long timeout;
private final long delay;

public PoolSettings(ConfigurationNode settingsNode) {
minPoolSize = settingsNode.node("min-idle").getInt();
Expand All @@ -868,6 +874,13 @@ public PoolSettings(ConfigurationNode settingsNode) {
t = new TimeUtil.Time(5L, TimeUnit.SECONDS);
}
timeout = t.getMillis();

String delayStr = settingsNode.node("delay").getString("1second");
t = delayStr.trim().equals("0") ? new TimeUtil.Time(0L, TimeUnit.SECONDS) : TimeUtil.getTime(delayStr);
if (t == null) {
t = new TimeUtil.Time(1L, TimeUnit.SECONDS);
}
delay = t.getMillis();
}

public int getMinPoolSize() { return minPoolSize; }
Expand All @@ -877,5 +890,7 @@ public PoolSettings(ConfigurationNode settingsNode) {
public long getMaxLifetime() { return maxLifetime; }

public long getTimeout() { return timeout; }

public long getDelay() { return delay; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ public static void conformVersion(
if (config.node("version").getDouble() == 5.1d) {
to52(config);
}
if (config.node("version").getDouble() == 5.2d) {
to53(config);
}

if (config.node("version").getDouble() != oldVersion) {
File backupFile = new File(fileOnDisk.getParent(), fileOnDisk.getName() + ".bak");
Expand Down Expand Up @@ -715,4 +718,14 @@ private static void to52(@NotNull CommentedConfigurationNode config) throws Seri
// Version
config.node("version").set(5.2d);
}

private static void to53(@NotNull CommentedConfigurationNode config) throws SerializationException {
// Add messaging->settings->delay
config.node("messaging", "settings", "delay").set("1second");
// Add messaging->settings->redundancy
config.node("messaging", "settings", "redundancy").set(Boolean.TRUE);

// Version
config.node("version").set(5.3d);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.PooledByteBufAllocator;
import me.egg82.antivpn.config.ConfigUtil;
import me.egg82.antivpn.core.Pair;
import me.egg82.antivpn.locale.LocaleUtil;
import me.egg82.antivpn.locale.MessageKey;
import me.egg82.antivpn.logging.GELFLogger;
Expand All @@ -14,6 +15,7 @@
import me.egg82.antivpn.messaging.packets.Packet;
import me.egg82.antivpn.messaging.packets.server.InitializationPacket;
import me.egg82.antivpn.messaging.packets.server.PacketVersionPacket;
import me.egg82.antivpn.services.CollectionProvider;
import me.egg82.antivpn.utils.MathUtil;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
Expand All @@ -26,6 +28,8 @@
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.text.DecimalFormat;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
Expand All @@ -49,15 +53,28 @@ public abstract class AbstractMessagingService implements MessagingService {
protected final File sentPacketDirectory;
protected final File receivedPacketDirectory;

protected AbstractMessagingService(@NotNull String name, @NotNull File packetDirectory) {
protected final long startupDelay;

protected AbstractMessagingService(@NotNull String name, long startupDelay, @NotNull File packetDirectory) {
this.name = name;
this.sentPacketDirectory = new File(packetDirectory, "sent");
this.receivedPacketDirectory = new File(packetDirectory, "received");
this.startupDelay = startupDelay;
}

@Override
public @NotNull String getName() { return name; }

@Override
public void flushPacketQueue(@NotNull UUID forServer) {
CollectionProvider.getPacketProcessingQueue().computeIfPresent(forServer, (k, v) -> {
for (Pair<UUID, Packet> p : v) {
handler.handlePacket(p.getT1(), name, p.getT2());
}
return null;
});
}

private static final double TOLERANCE = 1.1; // Compression ratio tolerance. Determines when compression should happen

protected final byte @NotNull [] compressData(@Nullable ByteBuf data) throws IOException {
Expand Down Expand Up @@ -89,14 +106,14 @@ protected AbstractMessagingService(@NotNull String name, @NotNull File packetDir
nd.readBytes(out, 1, uncompressedBytes);

if (ConfigUtil.getDebugOrFalse()) {
logger.info("Sent (no) compression: " + out.length + "/" + uncompressedBytes + " (" + ratioFormat.format((double) uncompressedBytes / (double) out.length) + ")");
logger.debug("Sent (no) compression: " + out.length + "/" + uncompressedBytes + " (" + ratioFormat.format((double) uncompressedBytes / (double) out.length) + ")");
}

return out;
}

if (ConfigUtil.getDebugOrFalse()) {
logger.info("Sent compression: " + (compressedBytes + 5) + "/" + uncompressedBytes + " (" + ratioFormat.format((double) uncompressedBytes / (double) (compressedBytes + 5)) + ")");
logger.debug("Sent compression: " + (compressedBytes + 5) + "/" + uncompressedBytes + " (" + ratioFormat.format((double) uncompressedBytes / (double) (compressedBytes + 5)) + ")");
}

dest.put(0, (byte) 0x01);
Expand Down Expand Up @@ -126,7 +143,7 @@ protected AbstractMessagingService(@NotNull String name, @NotNull File packetDir
data.readBytes(retVal);

if (ConfigUtil.getDebugOrFalse()) {
logger.info("Received (no) compression: " + compressedBytes + "/" + (compressedBytes - 1) + " (" + ratioFormat.format((double) (compressedBytes - 1) / (double) compressedBytes) + ")");
logger.debug("Received (no) compression: " + compressedBytes + "/" + (compressedBytes - 1) + " (" + ratioFormat.format((double) (compressedBytes - 1) / (double) compressedBytes) + ")");
}

return retVal;
Expand All @@ -147,7 +164,7 @@ protected AbstractMessagingService(@NotNull String name, @NotNull File packetDir
}

if (ConfigUtil.getDebugOrFalse()) {
logger.info("Received compression: " + compressedBytes + "/" + uncompressedBytes + " (" + ratioFormat.format((double) uncompressedBytes / (double) compressedBytes) + ")");
logger.debug("Received compression: " + compressedBytes + "/" + uncompressedBytes + " (" + ratioFormat.format((double) uncompressedBytes / (double) compressedBytes) + ")");
}

dest.rewind();
Expand Down Expand Up @@ -278,17 +295,40 @@ protected static boolean hasVersion(@NotNull Packet packet) {
return true;
}

int i = 0;
if (packet instanceof MultiPacket) {
MultiPacket mult = (MultiPacket) packet;
for (Packet p : mult.getPackets()) {
if (p instanceof InitializationPacket || p instanceof PacketVersionPacket) {
if (i > 0) {
reorder(mult);
}
return true;
}
i++;
}
}
return false;
}

private static void reorder(@NotNull MultiPacket packet) {
// TODO: There is definitely a more efficient way to do this, probably using streams

Set<Packet> removedPackets = new LinkedHashSet<>();
Set<Packet> keptPackets = new LinkedHashSet<>();

for (Packet p : packet.getPackets()) {
if (p instanceof InitializationPacket || p instanceof PacketVersionPacket) {
removedPackets.add(p);
} else {
keptPackets.add(p);
}
}

removedPackets.addAll(keptPackets);
packet.setPackets(removedPackets);
}

private void printBytes(@NotNull ByteBuf buffer) {
StringBuilder sb = new StringBuilder();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ public interface MessagingService {
boolean isClosed();

void sendPacket(@NotNull UUID messageId, @NotNull Packet packet) throws IOException, TimeoutException;

void flushPacketQueue(@NotNull UUID forServer);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,23 @@
import io.nats.client.Options;
import io.netty.buffer.ByteBuf;
import me.egg82.antivpn.config.ConfigUtil;
import me.egg82.antivpn.core.Pair;
import me.egg82.antivpn.locale.LocaleUtil;
import me.egg82.antivpn.locale.MessageKey;
import me.egg82.antivpn.messaging.handler.MessagingHandler;
import me.egg82.antivpn.messaging.packets.Packet;
import me.egg82.antivpn.messaging.packets.server.KeepAlivePacket;
import me.egg82.antivpn.messaging.packets.server.PacketVersionRequestPacket;
import me.egg82.antivpn.services.CollectionProvider;
import me.egg82.antivpn.utils.PacketUtil;
import org.jetbrains.annotations.NotNull;

import java.io.File;
import java.io.IOException;
import java.time.Duration;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

Expand All @@ -29,8 +35,8 @@ public class NATSMessagingService extends AbstractMessagingService {

private static final String SUBJECT_NAME = "avpn-data";

private NATSMessagingService(@NotNull String name, @NotNull File packetDirectory) {
super(name, packetDirectory);
private NATSMessagingService(@NotNull String name, long startupDelay, @NotNull File packetDirectory) {
super(name, startupDelay, packetDirectory);
}

@Override
Expand All @@ -56,15 +62,16 @@ public void close() {
@NotNull String name,
@NotNull UUID serverId,
@NotNull MessagingHandler handler,
long startupDelay,
@NotNull File packetDirectory
) { return new Builder(name, serverId, handler, packetDirectory); }
) { return new Builder(name, serverId, handler, startupDelay, packetDirectory); }

public static class Builder {
private final NATSMessagingService service;
private final Options.Builder config = new Options.Builder();

public Builder(@NotNull String name, @NotNull UUID serverId, @NotNull MessagingHandler handler, @NotNull File packetDirectory) {
service = new NATSMessagingService(name, packetDirectory);
public Builder(@NotNull String name, @NotNull UUID serverId, @NotNull MessagingHandler handler, long startupDelay, @NotNull File packetDirectory) {
service = new NATSMessagingService(name, startupDelay, packetDirectory);
service.serverId = serverId;
service.serverIdString = serverId.toString();
ByteBuf buffer = alloc.buffer(16, 16);
Expand Down Expand Up @@ -102,15 +109,26 @@ public Builder(@NotNull String name, @NotNull UUID serverId, @NotNull MessagingH
public @NotNull NATSMessagingService build() throws IOException, InterruptedException {
service.connection = Nats.connect(config.build());
// Indefinite subscription
subscribe();
if (service.startupDelay == 0L) {
subscribe();
} else {
CompletableFuture.runAsync(() -> {
try {
Thread.sleep(service.startupDelay);
} catch (InterruptedException ex) {
service.logger.error(ex.getClass().getName() + ": " + ex.getMessage(), ex);
Thread.currentThread().interrupt();
}
}).thenRun(this::subscribe);
}
return service;
}

private void subscribe() {
service.dispatcher = service.connection.createDispatcher(message -> {
String subject = message.getSubject();
if (ConfigUtil.getDebugOrFalse()) {
service.logger.info("Got message from subject: " + subject);
service.logger.debug("Got message from subject: " + subject);
}

try {
Expand Down Expand Up @@ -170,8 +188,33 @@ private void handleMessage(byte @NotNull [] body) throws IOException {
return;
}

if (packetVersion == -1 && packet instanceof KeepAlivePacket) {
// Don't send warning
return;
}

if (packetVersion == -1 && !hasVersion(packet)) {
service.logger.warn("Server " + sender + " packet version is unknown, and packet type is of " + packet.getClass().getName() + ". Skipping packet.");
// There's a potential race condition here with double-sending a request, but it doesn't really matter
ByteBuf finalData = data;
CollectionProvider.getPacketProcessingQueue().compute(sender, (k, v) -> {
if (v == null) {
v = new CopyOnWriteArrayList<>();
}

if (v.isEmpty()) {
if (packet.verifyFullRead(finalData)) {
v.add(new Pair<>(messageId, packet));
}
PacketUtil.queuePacket(new PacketVersionRequestPacket(sender, service.serverId));
} else {
if (packet.verifyFullRead(finalData)) {
v.add(new Pair<>(messageId, packet));
}
}

return v;
});
return;
}

Expand Down
Loading