Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature support proxy protocol #97

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cmux.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
if m.readTimeout > noTimeout {
_ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
}
if err := muc.checkPrefix(); err != nil {
_ = c.Close()
return
}
for _, sl := range m.sls {
for _, s := range sl.ss {
matched := s(muc.Conn, muc.startSniffing())
Expand Down Expand Up @@ -273,7 +277,9 @@ func (l muxListener) Accept() (net.Conn, error) {
// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
type MuxConn struct {
net.Conn
buf bufferedReader
buf bufferedReader
dstAddr *net.TCPAddr
srcAddr *net.TCPAddr
}

func newMuxConn(c net.Conn) *MuxConn {
Expand Down
115 changes: 115 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package cmux

import (
"bufio"
"bytes"
"fmt"
"net"
"strconv"
"strings"
)

const (
defaultBufSize = 1024
)

var (
// prefix is the string we look for at the start of a connection
// to check if this connection is using the proxy protocol
prefix = []byte("PROXY ")
prefixLen = len(prefix)
)

func (m *MuxConn) checkPrefix() error {
buf := make([]byte, defaultBufSize)
n, err := m.Read(buf)

reader := bufio.NewReader(bytes.NewReader(buf[:n]))

// Incrementally check each byte of the prefix
for i := 1; i <= prefixLen; i++ {
inp, err := reader.Peek(i)
if err != nil {
return err
}

// Check for a prefix mismatch, quit early
if !bytes.Equal(inp, prefix[:i]) {
m.buf.buffer.Write(buf[:n])
m.doneSniffing()
return nil
}
}

// Read the header line
headerLine, err := reader.ReadString('\n')
if err != nil {
return err
}

// Strip the carriage return and new line
header := headerLine[:len(headerLine)-2]

// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
parts := strings.Split(header, " ")
if len(parts) < 2 {
return fmt.Errorf("invalid header line: %s", header)
}

// Verify the type is known
switch parts[1] {
case "UNKNOWN":
return nil
case "TCP4":
case "TCP6":
default:
return fmt.Errorf("unhandled address type: %s", parts[1])
}

if len(parts) != 6 {
return fmt.Errorf("invalid header line: %s", header)
}

// Parse out the source address
ip := net.ParseIP(parts[2])
if ip == nil {
return fmt.Errorf("invalid source ip: %s", parts[2])
}
port, err := strconv.Atoi(parts[4])
if err != nil {
return fmt.Errorf("invalid source port: %s", parts[4])
}
m.srcAddr = &net.TCPAddr{IP: ip, Port: port}

// Parse out the destination address
ip = net.ParseIP(parts[3])
if ip == nil {
return fmt.Errorf("invalid destination ip: %s", parts[3])
}
port, err = strconv.Atoi(parts[5])
if err != nil {
return fmt.Errorf("invalid destination port: %s", parts[5])
}
m.dstAddr = &net.TCPAddr{IP: ip, Port: port}

if n != len(headerLine) {
m.buf.buffer.Write(buf[len(headerLine):n])
m.doneSniffing()
}

return nil
}

func (m *MuxConn) RemoteAddr() net.Addr {
if m.srcAddr != nil {
return m.srcAddr
}
return m.Conn.RemoteAddr()
}

func (m *MuxConn) LocalAddr() net.Addr {
if m.dstAddr != nil {
return m.dstAddr
}
return m.Conn.LocalAddr()
}
57 changes: 57 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package cmux

import (
"net"
"testing"
)

func TestMuxConn_CheckPrefix(t *testing.T) {
// Create a listener on a random port
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to create listener: %v", err)
}

go func() {
// Accept a connection from the listener
conn, err := listener.Accept()
if err != nil {
t.Errorf("failed to accept connection: %v", err)
return
}

// Write a PROXY header to the connection
_, err = conn.Write([]byte("PROXY TCP4 192.168.1.1 192.168.1.2 1234 5678\r\n"))
if err != nil {
t.Errorf("failed to write PROXY header: %v", err)
return
}

// Close the connection
conn.Close()
}()

// Dial the listener with a MuxConn
conn, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
t.Fatalf("failed to dial listener: %v", err)
}

muxConn := newMuxConn(conn)

// Call checkPrefix to parse the PROXY header
err = muxConn.checkPrefix()
if err != nil {
t.Errorf("checkPrefix returned error: %v", err)
}

// Verify the source and destination addresses were parsed correctly
expectedSrc := "192.168.1.1:1234"
expectedDst := "192.168.1.2:5678"
if muxConn.RemoteAddr().String() != expectedSrc {
t.Errorf("RemoteAddr() returned %s, expected %s", muxConn.RemoteAddr().String(), expectedSrc)
}
if muxConn.LocalAddr().String() != expectedDst {
t.Errorf("LocalAddr() returned %s, expected %s", muxConn.LocalAddr().String(), expectedDst)
}
}