Skip to content

Commit

Permalink
Java NIO performance and latency enhancements (#183)
Browse files Browse the repository at this point in the history
* Enhance performance throughput of nio java impl

* Remove unused running ops bits

* Make linter happy by running lint-fix
  • Loading branch information
baluchicken authored Jun 6, 2023
1 parent 2bfc8aa commit 8aed01d
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,8 @@ static int getSelectedKeyId(long selectedKey) {
return (int) selectedKey;
}

static int getAsyncRunningOpsFromSelectedKey(long selectedKey) {
return (int) (selectedKey >> 40 & 0xFF);
}

private final nasp.Selector selector = new nasp.Selector();
private final Map<Integer, SelectionKeyImpl> selectionKeyTable = new HashMap<>();
private final Map<Integer, Integer> runningAsyncOps = new HashMap<>();

protected NaspSelector(SelectorProvider sp) {
super(sp);
Expand Down Expand Up @@ -77,8 +72,6 @@ protected int doSelect(Consumer<SelectionKey> action, long timeout) throws IOExc
long selectedKey = selectedKeys.getLong();
int selectedKeyId = getSelectedKeyId(selectedKey);

updateRunningAsyncOps(selectedKeyId, getAsyncRunningOpsFromSelectedKey(selectedKey));

SelectionKeyImpl selectionKey = selectionKeyTable.get(selectedKeyId);
if (selectionKey != null) {
if (selectionKey.isValid()) {
Expand All @@ -90,32 +83,6 @@ protected int doSelect(Consumer<SelectionKey> action, long timeout) throws IOExc
return numKeysUpdated;
}

private void updateRunningAsyncOps(int selectedKeyId, int updateOps) {
int runningOps = 0;
Integer temp = runningAsyncOps.get(selectedKeyId);
if (temp != null) {
runningOps = temp;
}
if ((updateOps & SelectionKey.OP_READ) != 0) {
if ((runningOps & SelectionKey.OP_READ) == 0) {
runningAsyncOps.put(selectedKeyId, runningOps | SelectionKey.OP_READ);
}
} else {
if ((runningOps & SelectionKey.OP_READ) != 0) {
runningAsyncOps.put(selectedKeyId, runningOps & ~SelectionKey.OP_READ);
}
}
if ((updateOps & SelectionKey.OP_WRITE) != 0) {
if ((runningOps & SelectionKey.OP_WRITE) == 0) {
runningAsyncOps.put(selectedKeyId, runningOps | SelectionKey.OP_WRITE);
}
} else {
if ((runningOps & SelectionKey.OP_WRITE) != 0) {
runningAsyncOps.put(selectedKeyId, runningOps & ~SelectionKey.OP_WRITE);
}
}
}

private byte[] nativeSelect(long timeout) {
return selector.select(timeout);
}
Expand All @@ -130,9 +97,6 @@ public Selector wakeup() {
protected void implRegister(SelectionKeyImpl ski) {
super.implRegister(ski);
selectionKeyTable.put(ski.hashCode(), ski);

int interestOps = ski.interestOps();
handleAsyncOps(ski, interestOps);
}

@Override
Expand All @@ -143,84 +107,56 @@ protected void implClose() throws IOException {
@Override
protected void implDereg(SelectionKeyImpl ski) throws IOException {
selectionKeyTable.remove(ski.hashCode());
runningAsyncOps.remove(ski.hashCode());
}

@Override
protected void setEventOps(SelectionKeyImpl ski) {
int opsDiff = getAsyncOpsDiff(ski);

handleAsyncOps(ski, opsDiff);
}

/**
* Returns the diff between running async ops and the list the interest ops registered in the provided ski
*
* @param ski specifies the list of ops we are interested in
* @return the diff between running async ops and the list the interest ops
*/
protected int getAsyncOpsDiff(SelectionKeyImpl ski) {
if (ski == null)
return 0;

int interestOps = ski.interestOps();
Integer runningOps = runningAsyncOps.get(ski.hashCode());
if (runningOps == null || runningOps == 0) {
return interestOps;
if ((interestOps & SelectionKey.OP_WRITE) == 0)
{
selector.unregisterWriter(ski.hashCode());
}

int mask = interestOps ^ runningOps;

return interestOps & mask;
handleAsyncOps(ski);
}

protected void handleAsyncOps(SelectionKeyImpl ski, int opsDiff) {
int selectedKeyId = ski.hashCode();
int runningOps = 0;
if (runningAsyncOps.containsKey(selectedKeyId)) {
runningOps = runningAsyncOps.get(selectedKeyId);
}
protected void handleAsyncOps(SelectionKeyImpl ski) {

if (ski.channel() instanceof NaspServerSocketChannel naspServerSockChan) {
handleAsyncOps(naspServerSockChan, ski.hashCode(), runningOps, opsDiff);
handleAsyncOps(naspServerSockChan, ski.hashCode(), ski.interestOps());
return;
}

if (ski.channel() instanceof NaspSocketChannel naspSockChan) {
handleAsyncOps(naspSockChan, ski.hashCode(), runningOps, opsDiff);
handleAsyncOps(naspSockChan, ski.hashCode(), ski.interestOps());
}
}

protected void handleAsyncOps(NaspSocketChannel naspSocketChannel, int selectedKeyId, int runningOps, int opsDiff) {
if ((opsDiff & SelectionKey.OP_READ) != 0) {
protected void handleAsyncOps(NaspSocketChannel naspSocketChannel, int selectedKeyId, int interestedOps) {
if ((interestedOps & SelectionKey.OP_READ) != 0) {
naspSocketChannel.getConnection().startAsyncRead(selectedKeyId, selector);
runningAsyncOps.put(selectedKeyId, runningOps | SelectionKey.OP_READ);
}

if ((opsDiff & SelectionKey.OP_WRITE) != 0) {
if ((interestedOps & SelectionKey.OP_WRITE) != 0) {
naspSocketChannel.getConnection().startAsyncWrite(selectedKeyId, selector);
runningAsyncOps.put(selectedKeyId, runningOps | SelectionKey.OP_WRITE);
}

if ((opsDiff & SelectionKey.OP_CONNECT) != 0) {
if ((interestedOps & SelectionKey.OP_CONNECT) != 0) {
//This could happen if we are servers not clients
if (naspSocketChannel.getNaspTcpDialer() != null) {
InetSocketAddress address = naspSocketChannel.getAddress();
if (address != null) {
naspSocketChannel.getNaspTcpDialer().startAsyncDial(selectedKeyId, selector,
address.getHostString(), address.getPort());
runningAsyncOps.put(selectedKeyId, runningOps | SelectionKey.OP_CONNECT);
} else {
naspSocketChannel.setSelector(this);
}
}
}
}

protected void handleAsyncOps(NaspServerSocketChannel naspServerSocketChannel, int selectedKeyId, int runningOps, int opsDiff) {
if ((opsDiff & SelectionKey.OP_ACCEPT) != 0) {
protected void handleAsyncOps(NaspServerSocketChannel naspServerSocketChannel, int selectedKeyId, int interestedOps) {
if ((interestedOps & SelectionKey.OP_ACCEPT) != 0) {
naspServerSocketChannel.socket().getNaspTcpListener().startAsyncAccept(selectedKeyId, selector);
runningAsyncOps.put(selectedKeyId, runningOps | SelectionKey.OP_ACCEPT);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ public int getFDVal() {
throw new UnsupportedOperationException();
}

//POLLNVAL:32
//POLLERR:8
//POLLHUP:16
//POLLIN:1
//POLLOUT:4
public boolean translateReadyOps(int ops, int initialOps, SelectionKeyImpl ski) {
int intOps = ski.nioInterestOps();
int oldOps = ski.nioReadyOps();
Expand Down
82 changes: 30 additions & 52 deletions experimental/java/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ type Address struct {
}

// SelectedKey holds the data of a Selected Key formatted on 64 bits as follows:
// |----16bits---|-----------------------------8bits------------|---------8bits--------|----32bits-----|
// | unused |read/write running ops status for selected key|selected key ready ops|selected key id|
// |----24bits---|---------8bits--------|----32bits-----|
// | unused |selected key ready ops|selected key id|
type SelectedKey uint64

func NewSelectedKey(operation ReadyOps, id uint32) SelectedKey {
Expand Down Expand Up @@ -155,7 +155,10 @@ func (s *Selector) Select(timeoutMs int64) []byte {
s.readableKeys.Range(func(key, value any) bool {
check := value.(func() bool)
if check() {
selected[uint32(key.(int32))] |= NewSelectedKey(OP_READ, uint32(key.(int32)))
select {
case s.queue <- NewSelectedKey(OP_READ, uint32(key.(int32))):
default:
}
} else if v, _ := s.readInProgress.Load(key.(int32)); v != nil && !v.(bool) {
s.unregisterReader(key.(int32))
}
Expand All @@ -166,36 +169,17 @@ func (s *Selector) Select(timeoutMs int64) []byte {
s.writableKeys.Range(func(key, value any) bool {
check := value.(func() bool)
if check() {
selected[uint32(key.(int32))] |= NewSelectedKey(OP_WRITE, uint32(key.(int32)))
select {
case s.queue <- NewSelectedKey(OP_WRITE, uint32(key.(int32))):
default:
}
}
return true
})

if (timeoutMs == 0) && len(s.queue) == 0 && len(selected) == 0 {
if (timeoutMs == 0) && len(s.queue) == 0 {
return nil
}
if len(s.queue) == 0 && len(selected) != 0 {
b := make([]byte, len(selected)*8)
i := 0
for k, v := range selected {
// add current running ops to selected key
var runningOps uint64
readInProgress, _ := s.readInProgress.Load(int32(k))
//nolint:forcetypeassert
if readInProgress != nil && readInProgress.(bool) {
runningOps |= uint64(OP_READ)
}
writeInProgress, _ := s.writeInProgress.Load(int32(k))
//nolint:forcetypeassert
if writeInProgress != nil && writeInProgress.(bool) {
runningOps |= uint64(OP_WRITE)
}

binary.BigEndian.PutUint64(b[i*8:], uint64(v)|(runningOps<<40))
i++
}
return b
}

select {
case e := <-s.queue:
Expand All @@ -214,21 +198,8 @@ func (s *Selector) Select(timeoutMs int64) []byte {

b := make([]byte, len(selected)*8)
i := 0
for k, v := range selected {
// add current running ops to selected key
var runningOps uint64
readInProgress, _ := s.readInProgress.Load(int32(k))
//nolint:forcetypeassert
if readInProgress != nil && readInProgress.(bool) {
runningOps |= uint64(OP_READ)
}
writeInProgress, _ := s.writeInProgress.Load(int32(k))
//nolint:forcetypeassert
if writeInProgress != nil && writeInProgress.(bool) {
runningOps |= uint64(OP_WRITE)
}

binary.BigEndian.PutUint64(b[i*8:], uint64(v)|(runningOps<<40))
for _, v := range selected {
binary.BigEndian.PutUint64(b[i*8:], uint64(v))
i++
}
return b
Expand All @@ -241,9 +212,8 @@ func (s *Selector) registerWriter(selectedKeyId int32, check func() bool) {
s.writableKeys.Store(selectedKeyId, check)
}

func (s *Selector) unregisterWriter(selectedKeyId int32) {
func (s *Selector) UnregisterWriter(selectedKeyId int32) {
s.writableKeys.Delete(selectedKeyId)
s.writeInProgress.Delete(selectedKeyId)
}

func (s *Selector) registerReader(selectedKeyId int32, check func() bool) {
Expand Down Expand Up @@ -417,6 +387,12 @@ func (c *Connection) StartAsyncRead(selectedKeyId int32, selector *Selector) {
go func() {
tempBuffer := make([]byte, 1024)
for {
if c.readBufferLen.Load() > 0 {
select {
case selector.queue <- NewSelectedKey(OP_READ, uint32(selectedKeyId)):
default:
}
}
num, err := c.Read(tempBuffer)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) ||
Expand Down Expand Up @@ -451,7 +427,7 @@ func (c *Connection) StartAsyncRead(selectedKeyId int32, selector *Selector) {
logger.V(1).Info("received 0 bytes on connection")
}
}

selector.readInProgress.Delete(selectedKeyId)
logger.V(1).Info("StartAsyncRead finished")
}()
}
Expand All @@ -462,7 +438,7 @@ func (c *Connection) AsyncRead(b []byte) (int32, error) {
}
c.readLock.Lock()
defer c.readLock.Unlock()
max := int(math.Min(float64(c.readBuffer.Len()), float64(len(b))))
max := int(math.Min(float64(c.readBufferLen.Load()), float64(len(b))))
n, err := c.readBuffer.Read(b[:max])
c.readBufferLen.Add(int32(-n))
return int32(n), err
Expand All @@ -476,11 +452,6 @@ func (c *Connection) StartAsyncWrite(selectedKeyId int32, selector *Selector) {
logger := c.logger.WithValues(logCtx...)
logger.V(1).Info("StartAsyncWrite invoked")

//nolint:forcetypeassert
if v, _ := selector.writeInProgress.Swap(selectedKeyId, true); v != nil && v.(bool) {
return
}

selector.registerWriter(selectedKeyId, func() bool {
b := len(c.writeChannel) < MAX_WRITE_BUFFERS
if !b {
Expand All @@ -489,6 +460,11 @@ func (c *Connection) StartAsyncWrite(selectedKeyId int32, selector *Selector) {
return b
})

//nolint:forcetypeassert
if v, _ := selector.writeInProgress.Swap(selectedKeyId, true); v != nil && v.(bool) {
return
}

go func() {
out:
for buff := range c.writeChannel {
Expand All @@ -514,7 +490,8 @@ func (c *Connection) StartAsyncWrite(selectedKeyId int32, selector *Selector) {
}
}
}
selector.unregisterWriter(selectedKeyId)
selector.UnregisterWriter(selectedKeyId)
selector.writeInProgress.Delete(selectedKeyId)
logger.V(1).Info("StartAsyncWrite finished")
}()
}
Expand Down Expand Up @@ -586,6 +563,7 @@ func NewTCPDialer() (*TCPDialer, error) {
integrationHandler.cancel()
return nil, err
}
// dialer := &net.Dialer{}

return &TCPDialer{
dialer: dialer,
Expand Down

0 comments on commit 8aed01d

Please sign in to comment.