Skip to content

Commit

Permalink
fix: incidental packet drop and weird UDP state maintaining (#539)
Browse files Browse the repository at this point in the history
  • Loading branch information
mzz2017 authored Jun 16, 2024
1 parent ed50de2 commit 93e47ff
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 42 deletions.
17 changes: 14 additions & 3 deletions control/control_plane.go
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,15 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err
copy(newBuf, buf[:n])
newOob := pool.Get(oobn)
copy(newOob, oob[:oobn])
go func(data pool.PB, oob pool.PB, src netip.AddrPort) {
newSrc := src
convergeSrc := common.ConvergeAddrPort(src)
// Debug:
// t := time.Now()
DefaultUdpTaskPool.EmitTask(convergeSrc.String(), func() {
data := newBuf
oob := newOob
src := newSrc

defer data.Put()
defer oob.Put()
var realDst netip.AddrPort
Expand All @@ -777,10 +785,13 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err
} else {
realDst = pktDst
}
if e := c.handlePkt(udpConn, data, common.ConvergeAddrPort(src), common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult, false); e != nil {
if e := c.handlePkt(udpConn, data, convergeSrc, common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult, false); e != nil {
c.log.Warnln("handlePkt:", e)
}
}(newBuf, newOob, src)
})
// if d := time.Since(t); d > 100*time.Millisecond {
// logrus.Println(d)
// }
}
}()
c.ActivateCheck()
Expand Down
36 changes: 18 additions & 18 deletions control/kern/tproxy.c
Original file line number Diff line number Diff line change
Expand Up @@ -1290,24 +1290,19 @@ refresh_udp_conn_state_timer(struct tuples_key *key, bool is_egress)
if (unlikely(!value))
return NULL;

ret = bpf_timer_init(&value->timer, &udp_conn_state_map,
CLOCK_MONOTONIC);
if (unlikely(ret))
goto del;
if ((ret = bpf_timer_init(&value->timer, &udp_conn_state_map,
CLOCK_MONOTONIC)))
goto retn;

ret = bpf_timer_set_callback(&value->timer,
refresh_udp_conn_state_timer_cb);
if (unlikely(ret))
goto del;
if ((ret = bpf_timer_set_callback(&value->timer,
refresh_udp_conn_state_timer_cb)))
goto retn;

ret = bpf_timer_start(&value->timer, TIMEOUT_UDP_CONN_STATE, 0);
if (unlikely(ret))
goto del;
if ((ret = bpf_timer_start(&value->timer, TIMEOUT_UDP_CONN_STATE, 0)))
goto retn;

retn:
return value;
del:
bpf_map_delete_elem(&udp_conn_state_map, key);
return NULL;
}

SEC("tc/wan_ingress")
Expand Down Expand Up @@ -1515,17 +1510,22 @@ int tproxy_wan_egress(struct __sk_buff *skb)
flag[6] = tuples.dscp;
struct pid_pname *pid_pname;

if (pid_is_control_plane(skb, &pid_pname)) {
// from control plane
// => direct.
return TC_ACT_OK;
}

struct udp_conn_state *conn_state =
refresh_udp_conn_state_timer(&tuples.five, true);
if (!conn_state)
return TC_ACT_SHOT;
if (!conn_state->is_egress ||
pid_is_control_plane(skb, &pid_pname)) {
// Input udp connection or
// from control plane
if (!conn_state->is_egress) {
// Input udp connection
// => direct.
return TC_ACT_OK;
}

if (pid_pname) {
// 2, 3, 4, 5
__builtin_memcpy(&flag[2], pid_pname->pname,
Expand Down
2 changes: 1 addition & 1 deletion control/packet_sniffer_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type PacketSnifferKey struct {
RAddr netip.AddrPort
}

var DefaultPacketSnifferPool = NewPacketSnifferPool()
var DefaultPacketSnifferSessionMgr = NewPacketSnifferPool()

func NewPacketSnifferPool() *PacketSnifferPool {
return &PacketSnifferPool{}
Expand Down
4 changes: 2 additions & 2 deletions control/packet_sniffer_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ var testPacketSnifferData = []string{
func TestPacketSniffer_Normal(t *testing.T) {
for _, _data := range testPacketSnifferData {
data, _ := hex.DecodeString(_data)
sniffer, _ := DefaultPacketSnifferPool.GetOrCreate(PacketSnifferKey{
sniffer, _ := DefaultPacketSnifferSessionMgr.GetOrCreate(PacketSnifferKey{
LAddr: netip.MustParseAddrPort("1.1.1.1:1111"),
RAddr: netip.MustParseAddrPort("2.2.2.2:2222"),
}, nil)
Expand All @@ -44,7 +44,7 @@ func TestPacketSniffer_Mismatched(t *testing.T) {
dst := netip.MustParseAddrPort("2.2.2.2:2222")
for _, _data := range testPacketSnifferData {
data, _ := hex.DecodeString(_data)
sniffer, _ := DefaultPacketSnifferPool.GetOrCreate(PacketSnifferKey{
sniffer, _ := DefaultPacketSnifferSessionMgr.GetOrCreate(PacketSnifferKey{
LAddr: netip.MustParseAddrPort("1.1.1.1:1111"),
RAddr: dst,
}, nil)
Expand Down
67 changes: 53 additions & 14 deletions control/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"net"
"net/netip"

"time"

"github.com/daeuniverse/dae/common"
Expand All @@ -29,10 +30,11 @@ const (
)

type DialOption struct {
Target string
Dialer *dialer.Dialer
Outbound *ob.DialerGroup
Network string
Target string
Dialer *dialer.Dialer
Outbound *ob.DialerGroup
Network string
SniffedDomain string
}

func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout time.Duration) {
Expand Down Expand Up @@ -60,21 +62,50 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
var realSrc netip.AddrPort
var domain string
realSrc = src
ue, ueExists := DefaultUdpEndpointPool.Get(realSrc)
if ueExists && ue.SniffedDomain != "" {
// It is quic ...
// Fast path.
domain := ue.SniffedDomain
dialTarget := realDst.String()

if c.log.IsLevelEnabled(logrus.TraceLevel) {
fields := logrus.Fields{
"network": "udp(fp)",
"outbound": ue.Outbound.Name,
"policy": ue.Outbound.GetSelectionPolicy(),
"dialer": ue.Dialer.Property().Name,
"sniffed": domain,
"ip": RefineAddrPortToShow(realDst),
"pid": routingResult.Pid,
"dscp": routingResult.Dscp,
"pname": ProcessName2String(routingResult.Pname[:]),
"mac": Mac2String(routingResult.Mac[:]),
}
c.log.WithFields(fields).Tracef("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr()), dialTarget)
}

_, err = ue.WriteTo(data, dialTarget)
if err != nil {
return err
}
return nil
}

// To keep consistency with kernel program, we only sniff DNS request sent to 53.
dnsMessage, natTimeout := ChooseNatTimeout(data, realDst.Port() == 53)
// We should cache DNS records and set record TTL to 0, in order to monitor the dns req and resp in real time.
isDns := dnsMessage != nil
if !isDns && !skipSniffing && !DefaultUdpEndpointPool.Exists(realSrc) {
if !isDns && !skipSniffing && !ueExists {
// Sniff Quic, ...
key := PacketSnifferKey{
LAddr: realSrc,
RAddr: realDst,
}
_sniffer, _ := DefaultPacketSnifferPool.GetOrCreate(key, nil)
_sniffer, _ := DefaultPacketSnifferSessionMgr.GetOrCreate(key, nil)
_sniffer.Mu.Lock()
// Re-get sniffer from pool to confirm the transaction is not done.
sniffer := DefaultPacketSnifferPool.Get(key)
sniffer := DefaultPacketSnifferSessionMgr.Get(key)
if _sniffer == sniffer {
sniffer.AppendData(data)
domain, err = sniffer.SniffUdp()
Expand All @@ -92,7 +123,7 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
WithField("to", realDst).
Trace("sniffUdp")
}
defer DefaultPacketSnifferPool.Remove(key, sniffer)
defer DefaultPacketSnifferSessionMgr.Remove(key, sniffer)
// Re-handlePkt after self func.
toRehandle := sniffer.Data()[1 : len(sniffer.Data())-1] // Skip the first empty and the last (self).
sniffer.Mu.Unlock()
Expand Down Expand Up @@ -134,7 +165,6 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
// However, games may not use QUIC for communication, thus we cannot use domain to dial, which is fine.

// Get udp endpoint.
var ue *UdpEndpoint
retry := 0
networkType := &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_UDP,
Expand Down Expand Up @@ -217,10 +247,11 @@ getNew:
return nil, fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err)
}
return &DialOption{
Target: dialTarget,
Dialer: dialerForNew,
Outbound: outbound,
Network: common.MagicNetwork("udp", routingResult.Mark),
Target: dialTarget,
Dialer: dialerForNew,
Outbound: outbound,
Network: common.MagicNetwork("udp", routingResult.Mark),
SniffedDomain: domain,
}, nil
},
})
Expand All @@ -243,6 +274,10 @@ getNew:
retry++
goto getNew
}
if domain == "" {
// It is used for showing.
domain = ue.SniffedDomain
}

_, err = ue.WriteTo(data, dialTarget)
if err != nil {
Expand Down Expand Up @@ -280,7 +315,11 @@ getNew:
"pname": ProcessName2String(routingResult.Pname[:]),
"mac": Mac2String(routingResult.Mac[:]),
}
c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr()), dialTarget)
logger := c.log.WithFields(fields).Infof
if !isNew && c.log.IsLevelEnabled(logrus.DebugLevel) {
logger = c.log.WithFields(fields).Debugf
}
logger("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr()), dialTarget)
}

return nil
Expand Down
15 changes: 12 additions & 3 deletions control/udp_endpoint_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ type UdpEndpoint struct {

Dialer *dialer.Dialer
Outbound *outbound.DialerGroup

// Non-empty indicates this UDP Endpoint is related with a sniffed domain.
SniffedDomain string
DialTarget string
}

func (ue *UdpEndpoint) start() {
Expand Down Expand Up @@ -95,9 +99,12 @@ func (p *UdpEndpointPool) Remove(lAddr netip.AddrPort, udpEndpoint *UdpEndpoint)
return nil
}

func (p *UdpEndpointPool) Exists(lAddr netip.AddrPort) (ok bool) {
_, ok = p.pool.Load(lAddr)
return ok
func (p *UdpEndpointPool) Get(lAddr netip.AddrPort) (udpEndpoint *UdpEndpoint, ok bool) {
_ue, ok := p.pool.Load(lAddr)
if !ok {
return nil, ok
}
return _ue.(*UdpEndpoint), ok
}

func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEndpointOptions) (udpEndpoint *UdpEndpoint, isNew bool, err error) {
Expand Down Expand Up @@ -146,6 +153,8 @@ begin:
NatTimeout: createOption.NatTimeout,
Dialer: dialOption.Dialer,
Outbound: dialOption.Outbound,
SniffedDomain: dialOption.SniffedDomain,
DialTarget: dialOption.Target,
}
ue.deadlineTimer = time.AfterFunc(createOption.NatTimeout, func() {
if _ue, ok := p.pool.LoadAndDelete(lAddr); ok {
Expand Down
92 changes: 92 additions & 0 deletions control/udp_task_pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package control

import (
"sync"
"time"
)

const UdpTaskQueueLength = 128

type UdpTask = func()

type UdpTaskQueue struct {
ch chan UdpTask
timer *time.Timer
agingTime time.Duration
closed chan struct{}
freed chan struct{}
}

func (q *UdpTaskQueue) Push(task UdpTask) {
q.timer.Reset(q.agingTime)
q.ch <- task
}

type UdpTaskPool struct {
queueChPool sync.Pool
// mu protects m
mu sync.Mutex
m map[string]*UdpTaskQueue
}

func NewUdpTaskPool() *UdpTaskPool {
p := &UdpTaskPool{
queueChPool: sync.Pool{New: func() any {
return make(chan UdpTask, UdpTaskQueueLength)
}},
mu: sync.Mutex{},
m: map[string]*UdpTaskQueue{},
}
return p
}

func (p *UdpTaskPool) convoy(q *UdpTaskQueue) {
for {
select {
case <-q.closed:
clearloop:
for {
select {
case <-q.ch:
default:
break clearloop
}
}
close(q.freed)
return
case t := <-q.ch:
t()
}
}
}

func (p *UdpTaskPool) EmitTask(key string, task UdpTask) {
p.mu.Lock()
q, ok := p.m[key]
if !ok {
ch := p.queueChPool.Get().(chan UdpTask)
q = &UdpTaskQueue{
ch: ch,
timer: nil,
agingTime: DefaultNatTimeout,
closed: make(chan struct{}),
freed: make(chan struct{}),
}
q.timer = time.AfterFunc(q.agingTime, func() {
p.mu.Lock()
defer p.mu.Unlock()
if p.m[key] == q {
delete(p.m, key)
}
close(q.closed)
<-q.freed
p.queueChPool.Put(ch)
})
p.m[key] = q
go p.convoy(q)
}
p.mu.Unlock()
q.Push(task)
}

var DefaultUdpTaskPool = NewUdpTaskPool()
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ require (
github.com/gorilla/websocket v1.5.0 // indirect
github.com/klauspost/compress v1.17.4 // indirect
github.com/onsi/ginkgo/v2 v2.11.0 // indirect
github.com/stretchr/testify v1.8.1 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.20.0 // indirect
Expand Down
Loading

0 comments on commit 93e47ff

Please sign in to comment.