Skip to content

Commit

Permalink
Merge pull request #43 from multiformats/feat/expose-half-close
Browse files Browse the repository at this point in the history
expose methods from underlying connection types
  • Loading branch information
Stebalien authored Jun 22, 2018
2 parents 8792ba0 + a109f8d commit 31e031d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 24 deletions.
88 changes: 64 additions & 24 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,64 @@ type Conn interface {
RemoteMultiaddr() ma.Multiaddr
}

// WrapNetConn wraps a net.Conn object with a Multiaddr
// friendly Conn.
type halfOpen interface {
net.Conn
CloseRead() error
CloseWrite() error
}

func wrap(nconn net.Conn, laddr, raddr ma.Multiaddr) Conn {
endpts := maEndpoints{
laddr: laddr,
raddr: raddr,
}
// This sucks. However, it's the only way to reliably expose the
// underlying methods. This way, users that need access to, e.g.,
// CloseRead and CloseWrite, can do so via type assertions.
switch nconn := nconn.(type) {
case *net.TCPConn:
return &struct {
*net.TCPConn
maEndpoints
}{nconn, endpts}
case *net.UDPConn:
return &struct {
*net.UDPConn
maEndpoints
}{nconn, endpts}
case *net.IPConn:
return &struct {
*net.IPConn
maEndpoints
}{nconn, endpts}
case *net.UnixConn:
return &struct {
*net.UnixConn
maEndpoints
}{nconn, endpts}
case halfOpen:
return &struct {
halfOpen
maEndpoints
}{nconn, endpts}
default:
return &struct {
net.Conn
maEndpoints
}{nconn, endpts}
}
}

// WrapNetConn wraps a net.Conn object with a Multiaddr friendly Conn.
//
// This function does it's best to avoid "hiding" methods exposed by the wrapped
// type. Guarantees:
//
// * If the wrapped connection exposes the "half-open" closer methods
// (CloseWrite, CloseRead), these will be available on the wrapped connection
// via type assertions.
// * If the wrapped connection is a UnixConn, IPConn, TCPConn, or UDPConn, all
// methods on these wrapped connections will be available via type assertions.
func WrapNetConn(nconn net.Conn) (Conn, error) {
if nconn == nil {
return nil, fmt.Errorf("failed to convert nconn.LocalAddr: nil")
Expand All @@ -45,30 +101,23 @@ func WrapNetConn(nconn net.Conn) (Conn, error) {
return nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err)
}

return &maConn{
Conn: nconn,
laddr: laddr,
raddr: raddr,
}, nil
return wrap(nconn, laddr, raddr), nil
}

// maConn implements the Conn interface. It's a thin wrapper
// around a net.Conn
type maConn struct {
net.Conn
type maEndpoints struct {
laddr ma.Multiaddr
raddr ma.Multiaddr
}

// LocalMultiaddr returns the local address associated with
// this connection
func (c *maConn) LocalMultiaddr() ma.Multiaddr {
func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr {
return c.laddr
}

// RemoteMultiaddr returns the remote address associated with
// this connection
func (c *maConn) RemoteMultiaddr() ma.Multiaddr {
func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr {
return c.raddr
}

Expand Down Expand Up @@ -135,12 +184,7 @@ func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, er
return nil, err
}
}

return &maConn{
Conn: nconn,
laddr: local,
raddr: remote,
}, nil
return wrap(nconn, local, remote), nil
}

// Dial connects to a remote address. It uses an underlying net.Conn,
Expand Down Expand Up @@ -204,11 +248,7 @@ func (l *maListener) Accept() (Conn, error) {
return nil, fmt.Errorf("failed to convert connn.RemoteAddr: %s", err)
}

return &maConn{
Conn: nconn,
laddr: l.laddr,
raddr: raddr,
}, nil
return wrap(nconn, l.laddr, raddr), nil
}

// Multiaddr returns the listener's (local) Multiaddr.
Expand Down
4 changes: 4 additions & 0 deletions net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,12 +407,14 @@ func TestWrapNetConn(t *testing.T) {
defer wg.Done()
cB, err := listener.Accept()
checkErr(err, "failed to accept")
_ = cB.(halfOpen)
cB.Close()
}()

cA, err := net.Dial("tcp", listener.Addr().String())
checkErr(err, "failed to dial")
defer cA.Close()
_ = cA.(halfOpen)

lmaddr, err := FromNetAddr(cA.LocalAddr())
checkErr(err, "failed to get local addr")
Expand All @@ -422,6 +424,8 @@ func TestWrapNetConn(t *testing.T) {
mcA, err := WrapNetConn(cA)
checkErr(err, "failed to wrap conn")

_ = mcA.(halfOpen)

if mcA.LocalAddr().String() != cA.LocalAddr().String() {
t.Error("wrapped conn local addr differs")
}
Expand Down

0 comments on commit 31e031d

Please sign in to comment.