Skip to content

Commit

Permalink
Save busy states of channels and make sure to use only free ones
Browse files Browse the repository at this point in the history
  • Loading branch information
Skidamek committed Feb 24, 2025
1 parent f3ebc28 commit 9e9c124
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,25 @@
import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicBoolean;

import static pl.skidam.automodpack_core.GlobalVariables.MOD_ID;
import static pl.skidam.automodpack_core.netty.NetUtils.MAGIC_AMMC;

public class DownloadClient extends NettyClient {
private final List<Channel> channels = new ArrayList<>();
private final AtomicInteger roundRobinIndex = new AtomicInteger(0);
private final Map<Channel, AtomicBoolean> channels = new HashMap<>(); // channel, isBusy
private final EventLoopGroup group;
private final Bootstrap bootstrap;
private final int poolSize;
private final InetSocketAddress remoteAddress;
private final SslContext sslCtx;
private final Secrets.Secret secret;
private final DownloadClient downloadClient;
private final Semaphore channelLock = new Semaphore(0);

public DownloadClient(InetSocketAddress remoteAddress, Secrets.Secret secret, int poolSize) throws InterruptedException, SSLException {
this.downloadClient = this;
this.remoteAddress = remoteAddress;
this.secret = secret;
this.poolSize = poolSize;

// Yes, we use the insecure because server uses self-signed cert and we have different way to verify the authenticity
// Via secret and fingerprint, so the encryption strength should be the same, correct me if I'm wrong, thanks
Expand All @@ -60,7 +51,7 @@ public DownloadClient(InetSocketAddress remoteAddress, Secrets.Secret secret, in
.build();

group = new NioEventLoopGroup();
bootstrap = new Bootstrap();
Bootstrap bootstrap = new Bootstrap();
bootstrap.group(group)
.channel(NioSocketChannel.class)
.option(ChannelOption.SO_KEEPALIVE, true)
Expand Down Expand Up @@ -92,7 +83,12 @@ public void secureInit(ChannelHandlerContext ctx) {

@Override
public void addChannel(Channel channel) {
channels.add(channel);
channels.put(channel, new AtomicBoolean(false));
}

@Override
public void removeChannel(Channel channel) {
channels.remove(channel);
}

@Override
Expand All @@ -110,14 +106,19 @@ public Secrets.Secret getSecret() {
* Returns a CompletableFuture that completes when the download finishes.
*/
public CompletableFuture<Object> downloadFile(byte[] fileHash, Path destination) {
// Select first not busy channel
Channel channel = channels.entrySet().stream()
.filter(entry -> !entry.getValue().get())
.findFirst()
.map(Map.Entry::getKey)
.orElseThrow(() -> new IllegalStateException("No available channels"));

// Select a channel via round-robin.
int index = roundRobinIndex.getAndIncrement();
Channel channel = channels.get(index % channels.size());
// Mark channel as busy
channels.get(channel).set(true);

// Add a new FileDownloadHandler to process this download.
FileDownloadHandler downloadHandler = new FileDownloadHandler(destination);
channel.pipeline().addLast("downloadHandler-" + index, downloadHandler);
channel.pipeline().addLast("download-handler", downloadHandler);

byte[] bsecret = Base64.getUrlDecoder().decode(secret.secret());

Expand All @@ -126,21 +127,30 @@ public CompletableFuture<Object> downloadFile(byte[] fileHash, Path destination)
channel.writeAndFlush(request);

// Return the future that will complete when the download finishes.
return downloadHandler.getDownloadFuture();
return downloadHandler.getDownloadFuture().whenComplete((result, throwable) -> {
// Mark channel as not busy
channels.get(channel).set(false);
});
}

/**
* Downloads a file by its SHA-1 hash to the specified destination.
* Returns a CompletableFuture that completes when the download finishes.
*/
public CompletableFuture<Object> requestRefresh(byte[][] fileHashes) {
// Select a channel via round-robin.
int index = roundRobinIndex.getAndIncrement();
Channel channel = channels.get(index % channels.size());
// Select first not busy channel
Channel channel = channels.entrySet().stream()
.filter(entry -> !entry.getValue().get())
.findFirst()
.map(Map.Entry::getKey)
.orElseThrow(() -> new IllegalStateException("No available channels"));

// Mark channel as busy
channels.get(channel).set(true);

// Add a new FileDownloadHandler to process this download.
FileDownloadHandler downloadHandler = new FileDownloadHandler(null);
channel.pipeline().addLast("downloadHandler-" + index, downloadHandler);
channel.pipeline().addLast("download-handler", downloadHandler);

byte[] bsecret = Base64.getUrlDecoder().decode(secret.secret());

Expand All @@ -149,14 +159,17 @@ public CompletableFuture<Object> requestRefresh(byte[][] fileHashes) {
channel.writeAndFlush(request);

// Return the future that will complete when the download finishes.
return downloadHandler.getDownloadFuture();
return downloadHandler.getDownloadFuture().whenComplete((result, throwable) -> {
// Mark channel as not busy
channels.get(channel).set(false);
});
}

/**
* Closes all channels in the pool and shuts down the event loop.
*/
public void close() {
for (Channel channel : channels) {
for (Channel channel : channels.keySet()) {
if (channel.isOpen()) {
channel.close();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
Expand All @@ -22,23 +21,18 @@
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;

import static pl.skidam.automodpack_core.netty.NetUtils.MAGIC_AMMC;

public class EchoClient extends NettyClient {
private final List<Channel> channels = new ArrayList<>();
private final AtomicInteger roundRobinIndex = new AtomicInteger(0);
private final EventLoopGroup group;
private final Bootstrap bootstrap;
private final InetSocketAddress remoteAddress;
private final SslContext sslCtx;
private final EchoClient echoClient;
private final Semaphore channelLock = new Semaphore(0);

public EchoClient(InetSocketAddress remoteAddress) throws InterruptedException, SSLException {
this.echoClient = this;
this.remoteAddress = remoteAddress;

// Yes, we use the insecure because server uses self-signed cert and we have different way to verify the authenticity
// Via secret and fingerprint, so the encryption strength should be the same, correct me if I'm wrong, thanks
Expand All @@ -52,12 +46,8 @@ public EchoClient(InetSocketAddress remoteAddress) throws InterruptedException,
"TLS_CHACHA20_POLY1305_SHA256"))
.build();

String[] enabledProtocols = sslCtx.newEngine(ByteBufAllocator.DEFAULT).getEnabledProtocols();
System.out.println("Enabled protocols: " + String.join(", ", enabledProtocols));
System.out.println("Secure SslContext created using cipher suites: " + String.join(", ", sslCtx.cipherSuites()));

group = new NioEventLoopGroup();
bootstrap = new Bootstrap();
Bootstrap bootstrap = new Bootstrap();
bootstrap.group(group)
.channel(NioSocketChannel.class)
.option(ChannelOption.SO_KEEPALIVE, true)
Expand Down Expand Up @@ -87,6 +77,11 @@ public void addChannel(Channel channel) {
channels.add(channel);
}

@Override
public void removeChannel(Channel channel) {
channels.remove(channel);
}

@Override
public void releaseChannel() {
channelLock.release();
Expand All @@ -102,9 +97,7 @@ public Secrets.Secret getSecret() {
* Returns a CompletableFuture that completes when the download finishes.
*/
public CompletableFuture<Void> sendEcho(byte[] secret, byte[] data) {
// Select a channel via round-robin.
int index = roundRobinIndex.getAndIncrement();
Channel channel = channels.get(index % channels.size());
Channel channel = channels.get(0);

// Build and send the file request (which carries the secret and file hash).
EchoMessage request = new EchoMessage((byte) 1, secret, data);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public abstract class NettyClient {
public abstract SslContext getSslCtx();
public abstract void secureInit(ChannelHandlerContext ctx);
public abstract void addChannel(Channel channel);
public abstract void removeChannel(Channel channel);
public abstract void releaseChannel();
public abstract Secrets.Secret getSecret();
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
}
}

@Override
public void channelInactive(ChannelHandlerContext ctx) {
client.removeChannel(ctx.channel());
client.releaseChannel();
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
cause.printStackTrace();
Expand Down

0 comments on commit 9e9c124

Please sign in to comment.