Skip to content

Commit

Permalink
bug: fix the inconsistent behaviors on Windows
Browse files Browse the repository at this point in the history
Fixes #509
  • Loading branch information
panjf2000 committed Oct 19, 2023
1 parent bcc291f commit a6a1878
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 11 deletions.
10 changes: 1 addition & 9 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
---
name: Pull request
about: Propose changes to the code
title: ''
labels: ''
assignees: ''
---

<!--
Thank you for contributing to `gnet`! Please fill this out to help us make the most of your pull request.
Thank you for contributing to `gnet`! Please fill this out to help us review your pull request more efficiently.
Was this change discussed in an issue first? That can help save time in case the change is not a good fit for the project. Not all pull requests get merged.
Expand Down
106 changes: 105 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package gnet

import (
"io"
"math/rand"
"net"
"sync"
Expand Down Expand Up @@ -60,7 +61,8 @@ func (ev *clientEvents) OnTraffic(c Conn) (action Action) {
} else { // UDP
ev.packetLen = 1024
}
buf, _ := c.Next(-1)
buf, err := c.Next(-1)
assert.NoError(ev.tester, err)
p = append(p, buf...)
if len(p) < ev.packetLen {
c.SetContext(p)
Expand Down Expand Up @@ -338,6 +340,8 @@ func startGnetClient(t *testing.T, cli *Client, ev *clientEvents, network, addr
}
require.NoError(t, err)
defer c.Close()
err = c.Wake(nil)
require.NoError(t, err)
var rspCh chan []byte
if network == "udp" {
rspCh = make(chan []byte, 1)
Expand Down Expand Up @@ -387,3 +391,103 @@ func startGnetClient(t *testing.T, cli *Client, ev *clientEvents, network, addr
}
}
}

type clientEventsForWake struct {
BuiltinEventEngine
tester *testing.T
ch chan struct{}
}

func (ev *clientEventsForWake) OnBoot(_ Engine) Action {
ev.ch = make(chan struct{})
return None
}

func (ev *clientEventsForWake) OnTraffic(c Conn) (action Action) {
n, err := c.Read(nil)
assert.Zerof(ev.tester, n, "expected: %v, but got: %v", 0, n)
assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err)
buf := make([]byte, 10)
n, err = c.Read(buf)
assert.Zerof(ev.tester, n, "expected: %v, but got: %v", 0, n)
assert.ErrorIsf(ev.tester, err, io.ErrShortBuffer, "expected error: %v, but got: %v", io.ErrShortBuffer, err)
buf, err = c.Next(10)
assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf)
assert.ErrorIsf(ev.tester, err, io.ErrShortBuffer, "expected error: %v, but got: %v", io.ErrShortBuffer, err)
buf, err = c.Next(-1)
assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf)
assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err)
buf, err = c.Peek(10)
assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf)
assert.ErrorIsf(ev.tester, err, io.ErrShortBuffer, "expected error: %v, but got: %v", io.ErrShortBuffer, err)
buf, err = c.Peek(-1)
assert.Nilf(ev.tester, buf, "expected: %v, but got: %v", nil, buf)
assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err)
n, err = c.Discard(10)
assert.Zerof(ev.tester, n, "expected: %v, but got: %v", 0, n)
assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err)
n, err = c.Discard(-1)
assert.Zerof(ev.tester, n, "expected: %v, but got: %v", 0, n)
assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err)
m, err := c.WriteTo(io.Discard)
assert.Zerof(ev.tester, n, "expected: %v, but got: %v", 0, m)
assert.NoErrorf(ev.tester, err, "expected: %v, but got: %v", nil, err)
n = c.InboundBuffered()
assert.Zerof(ev.tester, n, "expected: %v, but got: %v", 0, m)
<-ev.ch
return None
}

type serverEventsForWake struct {
BuiltinEventEngine
network, addr string
client *Client
clientEV *clientEventsForWake
tester *testing.T
clients int32
started int32
}

func (ev *serverEventsForWake) OnOpen(_ Conn) ([]byte, Action) {
atomic.AddInt32(&ev.clients, 1)
return nil, None
}

func (ev *serverEventsForWake) OnClose(_ Conn, _ error) Action {
if atomic.AddInt32(&ev.clients, -1) == 0 {
return Shutdown
}
return None
}

func (ev *serverEventsForWake) OnTick() (time.Duration, Action) {
if atomic.CompareAndSwapInt32(&ev.started, 0, 1) {
go testConnWakeImmediately(ev.tester, ev.client, ev.clientEV, ev.network, ev.addr)
}
return 100 * time.Millisecond, None
}

func testConnWakeImmediately(t *testing.T, client *Client, clientEV *clientEventsForWake, network, addr string) {
c, err := client.Dial(network, addr)
assert.NoErrorf(t, err, "failed to dial: %v", err)
err = c.Wake(nil)
assert.NoError(t, err)
err = c.Close()
assert.NoError(t, err)
clientEV.ch <- struct{}{}
}

func TestWakeConnImmediately(t *testing.T) {
clientEV := &clientEventsForWake{tester: t}
client, err := NewClient(clientEV, WithLogLevel(logging.DebugLevel))
assert.NoError(t, err)

err = client.Start()
assert.NoError(t, err)
defer client.Stop() //nolint:errcheck

serverEV := &serverEventsForWake{tester: t, network: "tcp", addr: ":18888", client: client, clientEV: clientEV}

err = Run(serverEV, serverEV.network+"://"+serverEV.addr, WithTicker(true))
assert.NoError(t, err)
}
2 changes: 1 addition & 1 deletion connection_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func (c *conn) Read(p []byte) (n int, err error) {
n = copy(p, c.buffer)
c.buffer = c.buffer[n:]
if n == 0 && len(p) > 0 {
err = io.EOF
err = io.ErrShortBuffer
}
return
}
Expand Down
32 changes: 32 additions & 0 deletions connection_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ func (c *conn) resetBuffer() {
}

func (c *conn) Read(p []byte) (n int, err error) {
if c.buffer == nil {
if len(p) == 0 {
return 0, nil
}
return 0, io.ErrShortBuffer
}

if c.inboundBuffer.IsEmpty() {
n = copy(p, c.buffer.B)
c.buffer.B = c.buffer.B[n:]
Expand All @@ -130,6 +137,13 @@ func (c *conn) Read(p []byte) (n int, err error) {
}

func (c *conn) Next(n int) (buf []byte, err error) {
if c.buffer == nil {
if n <= 0 {
return nil, nil
}
return nil, io.ErrShortBuffer
}

inBufferLen := c.inboundBuffer.Buffered()
if totalLen := inBufferLen + c.buffer.Len(); n > totalLen {
return nil, io.ErrShortBuffer
Expand Down Expand Up @@ -160,6 +174,13 @@ func (c *conn) Next(n int) (buf []byte, err error) {
}

func (c *conn) Peek(n int) (buf []byte, err error) {
if c.buffer == nil {
if n <= 0 {
return nil, nil
}
return nil, io.ErrShortBuffer
}

inBufferLen := c.inboundBuffer.Buffered()
if totalLen := inBufferLen + c.buffer.Len(); n > totalLen {
return nil, io.ErrShortBuffer
Expand All @@ -186,6 +207,10 @@ func (c *conn) Peek(n int) (buf []byte, err error) {
}

func (c *conn) Discard(n int) (int, error) {
if c.buffer == nil {
return 0, nil
}

inBufferLen := c.inboundBuffer.Buffered()
tempBufferLen := c.buffer.Len()
if inBufferLen+tempBufferLen < n || n <= 0 {
Expand Down Expand Up @@ -242,6 +267,10 @@ func (c *conn) WriteTo(w io.Writer) (n int64, err error) {
return
}
}

if c.buffer == nil {
return 0, nil
}
defer c.buffer.Reset()
return c.buffer.WriteTo(w)
}
Expand All @@ -251,6 +280,9 @@ func (c *conn) Flush() error {
}

func (c *conn) InboundBuffered() int {
if c.buffer == nil {
return 0
}
return c.inboundBuffer.Buffered() + c.buffer.Len()
}

Expand Down

0 comments on commit a6a1878

Please sign in to comment.