Skip to content

Commit

Permalink
Improvements, mainly DownloadClient stuff
Browse files Browse the repository at this point in the history
ref: #319
  • Loading branch information
Skidamek committed Feb 25, 2025
1 parent 1f78a71 commit 04b26fd
Show file tree
Hide file tree
Showing 32 changed files with 167 additions and 822 deletions.
2 changes: 1 addition & 1 deletion core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies {
implementation("org.apache.logging.log4j:log4j-core:2.20.0")
implementation("com.google.code.gson:gson:2.10.1")
implementation("io.netty:netty-all:4.1.118.Final")
implementation("org.bouncycastle:bcprov-jdk18on:1.80")
// implementation("org.bouncycastle:bcprov-jdk18on:1.80")
implementation("org.bouncycastle:bcpkix-jdk18on:1.80")
implementation("com.github.luben:zstd-jni:1.5.7-1")
implementation("org.tomlj:tomlj:1.1.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import pl.skidam.automodpack_core.config.Jsons;
import pl.skidam.automodpack_core.loader.*;
import pl.skidam.automodpack_core.modpack.Modpack;
import pl.skidam.protocol.netty.NettyServer;
import pl.skidam.automodpack_core.protocol.netty.NettyServer;

import java.nio.file.Path;

public class GlobalVariables {
public static final Logger LOGGER = LogManager.getLogger("AutoModpack");
public static final String MOD_ID = "automodpack";
public static final String SECRET_REQUEST_HEADER = "AutoModpack-Secret";
public static Boolean DEBUG = false;
public static Boolean preload;
public static String MC_VERSION;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package pl.skidam.automodpack_core.auth;

import pl.skidam.protocol.NetUtils;
import pl.skidam.automodpack_core.protocol.NetUtils;

import java.net.SocketAddress;
import java.security.SecureRandom;
Expand Down Expand Up @@ -51,7 +51,7 @@ public static boolean isSecretValid(String secretStr, SocketAddress address) {
return false;

String playerUuid = playerSecretPair.getKey();
if (!GAME_CALL.canPlayerJoin(address, playerUuid)) // check if associated player is still whitelisted
if (!GAME_CALL.isPlayerAuthorized(address, playerUuid)) // check if associated player is still whitelisted
return false;

long secretLifetime = serverConfig.secretLifetime * 3600; // in seconds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
import java.net.SocketAddress;

public interface GameCallService {
boolean canPlayerJoin(SocketAddress address, String id);
boolean isPlayerAuthorized(SocketAddress address, String id);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

public class NullGameCall implements GameCallService {
@Override
public boolean canPlayerJoin(SocketAddress address, String id) {
public boolean isPlayerAuthorized(SocketAddress address, String id) {
return true;
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package pl.skidam.protocol;
package pl.skidam.automodpack_core.protocol;

import pl.skidam.automodpack_core.auth.Secrets;
import com.github.luben.zstd.Zstd;
Expand All @@ -12,14 +12,11 @@
import java.security.SecureRandom;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;

import static pl.skidam.protocol.NetUtils.*;
import static pl.skidam.automodpack_core.protocol.NetUtils.*;

/**
* A DownloadClient that creates a pool of connections.
Expand Down Expand Up @@ -124,6 +121,7 @@ public Connection(InetSocketAddress remoteAddress, Secrets.Secret secret) throws
sslSocket.close();
throw new IOException("Invalid server certificate chain");
}

boolean validated = false;
for (Certificate cert : certs) {
if (cert instanceof X509Certificate x509Cert) {
Expand Down Expand Up @@ -161,6 +159,7 @@ public void setBusy(boolean value) {
*/
public CompletableFuture<Object> sendDownloadFile(byte[] fileHash, Path destination, IntCallback chunkCallback) {
return CompletableFuture.supplyAsync(() -> {
Exception exception = null;
try {
// Build File Request message:
// [protocolVersion][FILE_REQUEST_TYPE][secret][int: fileHash.length][fileHash]
Expand All @@ -177,9 +176,10 @@ public CompletableFuture<Object> sendDownloadFile(byte[] fileHash, Path destinat
writeProtocolMessage(payload);
return readFileResponse(destination, chunkCallback);
} catch (Exception e) {
exception = e;
throw new CompletionException(e);
} finally {
setBusy(false);
finalBlock(exception);
}
}, executor);
}
Expand All @@ -189,6 +189,7 @@ public CompletableFuture<Object> sendDownloadFile(byte[] fileHash, Path destinat
*/
public CompletableFuture<Object> sendRefreshRequest(byte[][] fileHashes) {
return CompletableFuture.supplyAsync(() -> {
Exception exception = null;
try {
// Build Refresh Request message:
// [protocolVersion][REFRESH_REQUEST_TYPE][secret][int: fileHashesCount]
Expand All @@ -211,13 +212,32 @@ public CompletableFuture<Object> sendRefreshRequest(byte[][] fileHashes) {
writeProtocolMessage(payload);
return readFileResponse(null, null);
} catch (Exception e) {
exception = e;
throw new CompletionException(e);
} finally {
setBusy(false);
finalBlock(exception);
}
}, executor);
}

private void finalBlock(Exception exception) {
// skip any remaining data
try {
while (in.available() > 0) {
in.skipBytes(in.available());
}
} catch (IOException e) {
if (exception == null) {
exception = e;
throw new CompletionException(e);
}
} finally {
if (exception == null) {
setBusy(false);
}
}
}

/**
* Compresses and writes a protocol message using Zstd.
* Message framing: [int: compressedLength][int: originalLength][compressed payload].
Expand Down Expand Up @@ -251,65 +271,66 @@ private byte[] readProtocolMessageFrame() throws IOException {
private Object readFileResponse(Path destination, IntCallback chunkCallback) throws IOException {
// Header frame
byte[] headerFrame = readProtocolMessageFrame();
DataInputStream headerIn = new DataInputStream(new ByteArrayInputStream(headerFrame));
byte version = headerIn.readByte();
byte messageType = headerIn.readByte();
if (messageType == ERROR) {
int errLen = headerIn.readInt();
byte[] errBytes = new byte[errLen];
headerIn.readFully(errBytes);
throw new IOException("Server error: " + new String(errBytes));
}
if (messageType != FILE_RESPONSE_TYPE) {
throw new IOException("Unexpected message type: " + messageType);
}
long expectedFileSize = headerIn.readLong();

long receivedBytes = 0;
OutputStream fos = null;
List<byte[]> rawData = null;
if (destination != null) {
fos = new FileOutputStream(destination.toFile());
} else {
rawData = new LinkedList<>();
}
try (DataInputStream headerIn = new DataInputStream(new ByteArrayInputStream(headerFrame))) {
byte version = headerIn.readByte();
byte messageType = headerIn.readByte();

if (messageType == ERROR) {
int errLen = headerIn.readInt();
byte[] errBytes = new byte[errLen];
headerIn.readFully(errBytes);
throw new IOException("Server error: " + new String(errBytes));
}

long receivedBytes = 0;
OutputStream fos = (destination != null) ? new FileOutputStream(destination.toFile()) : null;
List<byte[]> rawData = (fos == null) ? new LinkedList<>() : null;

// Read data frames until the expected file size is received.
while (receivedBytes < expectedFileSize) {
byte[] dataFrame = readProtocolMessageFrame();
int toWrite = dataFrame.length;
if (receivedBytes + toWrite > expectedFileSize) {
toWrite = (int)(expectedFileSize - receivedBytes);
if (messageType == END_OF_TRANSMISSION) {
if (fos != null) fos.close();
return (rawData != null) ? rawData : destination;
}
if (fos != null) {
fos.write(dataFrame, 0, toWrite);
} else {
byte[] chunk = new byte[toWrite];
System.arraycopy(dataFrame, 0, chunk, 0, toWrite);
rawData.add(chunk);

if (messageType != FILE_RESPONSE_TYPE) {
if (fos != null) fos.close();
throw new IOException("Unexpected message type: " + messageType);
}
receivedBytes += toWrite;

if (chunkCallback != null) {
chunkCallback.run(toWrite);
long expectedFileSize = headerIn.readLong();

// Read data frames until the expected file size is received.
while (receivedBytes < expectedFileSize) {
byte[] dataFrame = readProtocolMessageFrame();
int toWrite = Math.min(dataFrame.length, (int)(expectedFileSize - receivedBytes));

if (fos != null) {
fos.write(dataFrame, 0, toWrite);
} else {
byte[] chunk = Arrays.copyOfRange(dataFrame, 0, toWrite);
rawData.add(chunk);
}
receivedBytes += toWrite;

if (chunkCallback != null) {
chunkCallback.run(toWrite);
}
}
}

// Read EOT frame
byte[] eotFrame = readProtocolMessageFrame();
DataInputStream eotIn = new DataInputStream(new ByteArrayInputStream(eotFrame));
byte ver = eotIn.readByte();
byte eotType = eotIn.readByte();
if (ver != version || eotType != END_OF_TRANSMISSION) {
throw new IOException("Invalid end-of-transmission marker. Expected version " + version +
" and type " + END_OF_TRANSMISSION + ", got version " + ver + " and type " + eotType);
}
if (fos != null) fos.close();

// Read EOT frame
byte[] eotFrame = readProtocolMessageFrame();
try (DataInputStream eotIn = new DataInputStream(new ByteArrayInputStream(eotFrame))) {
byte ver = eotIn.readByte();
byte eotType = eotIn.readByte();

if (ver != version || eotType != END_OF_TRANSMISSION) {
throw new IOException("Invalid end-of-transmission marker. Expected version " + version +
" and type " + END_OF_TRANSMISSION + ", got version " + ver + " and type " + eotType);
}
}

if (fos != null) {
fos.close();
return destination;
} else {
return rawData;
return (rawData != null) ? rawData : destination;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package pl.skidam.protocol;
package pl.skidam.automodpack_core.protocol;

import org.bouncycastle.asn1.x500.X500Name;
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package pl.skidam.protocol.netty;
package pl.skidam.automodpack_core.protocol.netty;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
Expand All @@ -12,8 +12,8 @@
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import pl.skidam.automodpack_core.config.ConfigTools;
import pl.skidam.protocol.NetUtils;
import pl.skidam.protocol.netty.handler.ProtocolServerHandler;
import pl.skidam.automodpack_core.protocol.NetUtils;
import pl.skidam.automodpack_core.protocol.netty.handler.ProtocolServerHandler;
import pl.skidam.automodpack_core.utils.CustomThreadFactoryBuilder;
import pl.skidam.automodpack_core.utils.Ip;
import pl.skidam.automodpack_core.utils.ObservableMap;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package pl.skidam.protocol.netty.handler;
package pl.skidam.automodpack_core.protocol.netty.handler;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import pl.skidam.protocol.NetUtils;
import pl.skidam.protocol.netty.message.EchoMessage;
import pl.skidam.protocol.netty.message.FileRequestMessage;
import pl.skidam.protocol.netty.message.FileResponseMessage;
import pl.skidam.protocol.netty.message.RefreshRequestMessage;
import pl.skidam.automodpack_core.protocol.NetUtils;
import pl.skidam.automodpack_core.protocol.netty.message.EchoMessage;
import pl.skidam.automodpack_core.protocol.netty.message.FileRequestMessage;
import pl.skidam.automodpack_core.protocol.netty.message.FileResponseMessage;
import pl.skidam.automodpack_core.protocol.netty.message.RefreshRequestMessage;

import java.util.List;

import static pl.skidam.protocol.NetUtils.*;
import static pl.skidam.automodpack_core.protocol.NetUtils.*;

public class ProtocolMessageDecoder extends ByteToMessageDecoder {
@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package pl.skidam.protocol.netty.handler;
package pl.skidam.automodpack_core.protocol.netty.handler;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import pl.skidam.protocol.netty.message.*;
import pl.skidam.automodpack_core.protocol.netty.message.*;

import static pl.skidam.protocol.NetUtils.*;
import static pl.skidam.automodpack_core.protocol.NetUtils.*;

public class ProtocolMessageEncoder extends MessageToByteEncoder<ProtocolMessage> {
@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package pl.skidam.protocol.netty.handler;
package pl.skidam.automodpack_core.protocol.netty.handler;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
Expand All @@ -8,7 +8,7 @@

import java.util.List;

import static pl.skidam.protocol.NetUtils.*;
import static pl.skidam.automodpack_core.protocol.NetUtils.*;

public class ProtocolServerHandler extends ByteToMessageDecoder {

Expand Down
Loading

0 comments on commit 04b26fd

Please sign in to comment.