Skip to content

Commit

Permalink
firewall reject packets: cleanup error cases (#957)
Browse files Browse the repository at this point in the history
  • Loading branch information
wadey authored Nov 13, 2023
1 parent 3356e03 commit fe16ea5
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 14 deletions.
26 changes: 20 additions & 6 deletions inside.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
}

out = iputil.CreateRejectPacket(packet, out)
if len(out) == 0 {
return
}

_, err := f.readers[q].Write(out)
if err != nil {
f.l.WithError(err).Error("Failed to write to tun")
Expand All @@ -94,12 +98,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
return
}

// Use some out buffer space to build the packet before encryption
// Need 40 bytes for the reject packet (20 byte ipv4 header, 20 byte tcp rst packet)
// Leave 100 bytes for the encrypted packet (60 byte Nebula header, 40 byte reject packet)
out = out[:140]
outPacket := iputil.CreateRejectPacket(packet, out[100:])
f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, outPacket, nb, out, q)
out = iputil.CreateRejectPacket(packet, out)
if len(out) == 0 {
return
}

if len(out) > iputil.MaxRejectPacketSize {
if f.l.GetLevel() >= logrus.InfoLevel {
f.l.
WithField("packet", packet).
WithField("outPacket", out).
Info("rejectOutside: packet too big, not sending")
}
return
}

f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q)
}

func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
Expand Down
41 changes: 34 additions & 7 deletions iputil/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,19 @@ import (
"golang.org/x/net/ipv4"
)

const (
// Need 96 bytes for the largest reject packet:
// - 20 byte ipv4 header
// - 8 byte icmpv4 header
// - 68 byte body (60 byte max orig ipv4 header + 8 byte orig icmpv4 header)
MaxRejectPacketSize = ipv4.HeaderLen + 8 + 60 + 8
)

func CreateRejectPacket(packet []byte, out []byte) []byte {
// TODO ipv4 only, need to fix when inside supports ipv6
if len(packet) < ipv4.HeaderLen || int(packet[0]>>4) != ipv4.Version {
return nil
}

switch packet[9] {
case 6: // tcp
return ipv4CreateRejectTCPPacket(packet, out)
Expand All @@ -19,20 +30,28 @@ func CreateRejectPacket(packet []byte, out []byte) []byte {
func ipv4CreateRejectICMPPacket(packet []byte, out []byte) []byte {
ihl := int(packet[0]&0x0f) << 2

// ICMP reply includes header and first 8 bytes of the packet
if len(packet) < ihl {
// We need at least this many bytes for this to be a valid packet
return nil
}

// ICMP reply includes original header and first 8 bytes of the packet
packetLen := len(packet)
if packetLen > ihl+8 {
packetLen = ihl + 8
}

outLen := ipv4.HeaderLen + 8 + packetLen
if outLen > cap(out) {
return nil
}

out = out[:(outLen)]
out = out[:outLen]

ipHdr := out[0:ipv4.HeaderLen]
ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl
ipHdr[1] = 0 // DSCP, ECN
binary.BigEndian.PutUint16(ipHdr[2:], uint16(ipv4.HeaderLen+8+packetLen)) // Total Length
ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl
ipHdr[1] = 0 // DSCP, ECN
binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length

ipHdr[4] = 0 // id
ipHdr[5] = 0 // .
Expand Down Expand Up @@ -76,7 +95,15 @@ func ipv4CreateRejectTCPPacket(packet []byte, out []byte) []byte {
ihl := int(packet[0]&0x0f) << 2
outLen := ipv4.HeaderLen + tcpLen

out = out[:(outLen)]
if len(packet) < ihl+tcpLen {
// We need at least this many bytes for this to be a valid packet
return nil
}
if outLen > cap(out) {
return nil
}

out = out[:outLen]

ipHdr := out[0:ipv4.HeaderLen]
ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl
Expand Down
73 changes: 73 additions & 0 deletions iputil/packet_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package iputil

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
)

func Test_CreateRejectPacket(t *testing.T) {
h := ipv4.Header{
Len: 20,
Src: net.IPv4(10, 0, 0, 1),
Dst: net.IPv4(10, 0, 0, 2),
Protocol: 1, // ICMP
}

b, err := h.Marshal()
if err != nil {
t.Fatalf("h.Marhshal: %v", err)
}
b = append(b, []byte{0, 3, 0, 4}...)

expectedLen := ipv4.HeaderLen + 8 + h.Len + 4
out := make([]byte, expectedLen)
rejectPacket := CreateRejectPacket(b, out)
assert.NotNil(t, rejectPacket)
assert.Len(t, rejectPacket, expectedLen)

// ICMP with max header len
h = ipv4.Header{
Len: 60,
Src: net.IPv4(10, 0, 0, 1),
Dst: net.IPv4(10, 0, 0, 2),
Protocol: 1, // ICMP
Options: make([]byte, 40),
}

b, err = h.Marshal()
if err != nil {
t.Fatalf("h.Marhshal: %v", err)
}
b = append(b, []byte{0, 3, 0, 4, 0, 0, 0, 0}...)

expectedLen = MaxRejectPacketSize
out = make([]byte, MaxRejectPacketSize)
rejectPacket = CreateRejectPacket(b, out)
assert.NotNil(t, rejectPacket)
assert.Len(t, rejectPacket, expectedLen)

// TCP with max header len
h = ipv4.Header{
Len: 60,
Src: net.IPv4(10, 0, 0, 1),
Dst: net.IPv4(10, 0, 0, 2),
Protocol: 6, // TCP
Options: make([]byte, 40),
}

b, err = h.Marshal()
if err != nil {
t.Fatalf("h.Marhshal: %v", err)
}
b = append(b, []byte{0, 3, 0, 4}...)
b = append(b, make([]byte, 16)...)

expectedLen = ipv4.HeaderLen + 20
out = make([]byte, expectedLen)
rejectPacket = CreateRejectPacket(b, out)
assert.NotNil(t, rejectPacket)
assert.Len(t, rejectPacket, expectedLen)
}
4 changes: 3 additions & 1 deletion outside.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out

dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q)
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Expand Down

0 comments on commit fe16ea5

Please sign in to comment.