Skip to content

Commit

Permalink
Cert v2 + tun changes for Linux (#1224)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackDoanRivian authored Sep 21, 2024
1 parent bf79947 commit 28cd257
Show file tree
Hide file tree
Showing 19 changed files with 396 additions and 207 deletions.
15 changes: 10 additions & 5 deletions cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ type CachedCertificate struct {
func UnmarshalCertificate(b []byte) (Certificate, error) {
//TODO: you left off here, no one uses this function but it might be beneficial to export _something_ that someone can use, maybe the Versioned unmarshallsers?
var c Certificate
c, err := unmarshalCertificateV2(b, nil)
c, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
if err == nil {
return c, nil
}
Expand All @@ -129,15 +129,15 @@ func UnmarshalCertificate(b []byte) (Certificate, error) {
// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake.
// Handshakes save space by placing the peers public key in a different part of the packet, we have to
// reassemble the actual certificate structure with that in mind.
func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (Certificate, error) {
func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
var c Certificate
var err error

switch v {
case VersionPre1, Version1:
c, err = unmarshalCertificateV1(b, publicKey)
case Version2:
c, err = unmarshalCertificateV2(b, publicKey)
c, err = unmarshalCertificateV2(b, publicKey, curve)
default:
//TODO: make a static var
return nil, fmt.Errorf("unknown certificate version %d", v)
Expand All @@ -146,10 +146,15 @@ func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (C
if err != nil {
return nil, err
}

if c.Curve() != curve {
return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
}

return c, nil
}

func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, caPool *CAPool) (*CachedCertificate, error) {
func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
if publicKey == nil {
return nil, ErrNoPeerStaticKey
}
Expand All @@ -158,7 +163,7 @@ func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, caPool *CAP
return nil, ErrNoPayload
}

c, err := UnmarshalCertificateFromHandshake(v, rawCertBytes, publicKey)
c, err := UnmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
}
Expand Down
4 changes: 1 addition & 3 deletions cert/cert_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"time"

"golang.org/x/crypto/curve25519"
Expand Down Expand Up @@ -393,8 +392,7 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error)
}
}

slices.SortFunc(nc.details.networks, comparePrefix)
slices.SortFunc(nc.details.unsafeNetworks, comparePrefix)
//do not sort the subnets field for V1 certs

return &nc, nil
}
Expand Down
7 changes: 4 additions & 3 deletions cert/cert_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ func (d *detailsV2) Marshal() ([]byte, error) {
return b.Bytes()
}

func unmarshalCertificateV2(b []byte, publicKey []byte) (*certificateV2, error) {
func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
l := len(b)
if l == 0 || l > MaxCertificateSize {
return nil, ErrBadFormat
Expand All @@ -473,11 +473,12 @@ func unmarshalCertificateV2(b []byte, publicKey []byte) (*certificateV2, error)
return nil, ErrBadFormat
}

//Maybe grab the curve
var rawCurve byte
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(Curve_CURVE25519)) {
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
return nil, ErrBadFormat
}
curve := Curve(rawCurve)
curve = Curve(rawCurve)

// Maybe grab the public key
var rawPublicKey cryptobyte.String
Expand Down
2 changes: 1 addition & 1 deletion cert/pem.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
case CertificateBanner:
c, err = unmarshalCertificateV1(p.Bytes, nil)
case CertificateV2Banner:
c, err = unmarshalCertificateV2(p.Bytes, nil)
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
default:
return nil, r, ErrInvalidPEMCertificateBanner
}
Expand Down
4 changes: 4 additions & 0 deletions connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
"message_counter": cs.messageCounter.Load(),
})
}

func (cs *ConnectionState) Curve() cert.Curve {
return cs.myCert.Curve()
}
9 changes: 9 additions & 0 deletions dns_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeAAAA:
l.Debugf("Query for AAAA %s", q.Name)
ip := dnsR.Query(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
b, err := netip.ParseAddr(a)
Expand Down
4 changes: 2 additions & 2 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
}

remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool())
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil {
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
Expand Down Expand Up @@ -404,7 +404,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
return true
}

remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool())
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil {
e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
Expand Down
3 changes: 3 additions & 0 deletions lighthouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAdd
useVersion = 2
}

//todo hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
//todo why do we care about the vpnip in the packet? We know where it came from, right?

if detailsVpnIp != vpnAddrs[0] {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", vpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update")
Expand Down
111 changes: 101 additions & 10 deletions outside.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"net/netip"
"time"

"github.com/google/gopacket/layers"
"golang.org/x/net/ipv6"

"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
Expand Down Expand Up @@ -297,22 +300,112 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h

// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
if len(data) < 1 {
return errors.New("packet too short")
}

version := int((data[0] >> 4) & 0x0f)
switch version {
case ipv4.Version:
return parseV4(data, incoming, fp)
case ipv6.Version:
return parseV6(data, incoming, fp)
}
return fmt.Errorf("packet is an unknown ip version: %v", version)
}

func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
dataLen := len(data)
if dataLen < ipv6.HeaderLen {
return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen)
}

if incoming {
fp.RemoteIP, _ = netip.AddrFromSlice(data[8:24])
fp.LocalIP, _ = netip.AddrFromSlice(data[24:40])
} else {
fp.LocalIP, _ = netip.AddrFromSlice(data[8:24])
fp.RemoteIP, _ = netip.AddrFromSlice(data[24:40])
}

//TODO: whats a reasonable number of extension headers to attempt to parse?
//https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html
protoAt := 6
offset := 40
for i := 0; i < 24; i++ {
if dataLen < offset {
break
}

proto := layers.IPProtocol(data[protoAt])
//fmt.Println(proto, protoAt)
switch proto {
case layers.IPProtocolICMPv6:
//TODO: we need a new protocol in config language "icmpv6"
fp.Protocol = uint8(proto)
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = false
return nil

// Is it an ipv4 packet?
if int((data[0]>>4)&0x0f) != 4 {
return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
case layers.IPProtocolTCP:
if dataLen < offset+4 {
return fmt.Errorf("ipv6 packet was too small")
}
fp.Protocol = uint8(proto)
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
fp.Fragment = false
return nil

case layers.IPProtocolUDP:
if dataLen < offset+4 {
return fmt.Errorf("ipv6 packet was too small")
}
fp.Protocol = uint8(proto)
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
fp.Fragment = false
return nil

case layers.IPProtocolIPv6Fragment:
//TODO: can we determine the protocol?
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = true
return nil

default:
if dataLen < offset+1 {
break
}

next := int(data[offset+1]) * 8
if next == 0 {
// each extension is at least 8 bytes
next = 8
}

protoAt = offset
offset = offset + next
}
}

return fmt.Errorf("could not find payload in ipv6 packet")
}

func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen)
}

// Adjust our start position based on the advertised ip header length
ihl := int(data[0]&0x0f) << 2

// Well formed ip header length?
if ihl < ipv4.HeaderLen {
return fmt.Errorf("packet had an invalid header length: %v", ihl)
return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl)
}

// Check if this is the second or further fragment of a fragmented packet.
Expand All @@ -328,12 +421,11 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
minLen += minFwPacketLen
}
if len(data) < minLen {
return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl)
}

// Firewall packets are locally oriented
if incoming {
//TODO: IPV6-WORK
fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
Expand All @@ -344,7 +436,6 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
} else {
//TODO: IPV6-WORK
fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
Expand Down
18 changes: 12 additions & 6 deletions outside_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@ import (
func Test_newPacket(t *testing.T) {
p := &firewall.Packet{}

// length fail
err := newPacket([]byte{0, 1}, true, p)
assert.EqualError(t, err, "packet is less than 20 bytes")
// length fails
err := newPacket([]byte{}, true, p)
assert.EqualError(t, err, "packet too short")

err = newPacket([]byte{0x40}, true, p)
assert.EqualError(t, err, "ipv4 packet is less than 20 bytes")

err = newPacket([]byte{0x60}, true, p)
assert.EqualError(t, err, "ipv6 packet is less than 20 bytes")

// length fail with ip options
h := ipv4.Header{
Expand All @@ -29,15 +35,15 @@ func Test_newPacket(t *testing.T) {
b, _ := h.Marshal()
err = newPacket(b, true, p)

assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24")

// not an ipv4 packet
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet is not ipv4, type: 0")
assert.EqualError(t, err, "packet is an unknown ip version: 0")

// invalid ihl
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet had an invalid header length: 8")
assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8")

// account for variable ip header length - incoming
h = ipv4.Header{
Expand Down
Loading

0 comments on commit 28cd257

Please sign in to comment.