@@ -30,24 +30,30 @@ const (
3030
3131 chainPOSTROUTING = "POSTROUTING"
3232 chainPREROUTING = "PREROUTING"
33+ chainFORWARD = "FORWARD"
3334 chainRTNAT = "NETBIRD-RT-NAT"
3435 chainRTFWDIN = "NETBIRD-RT-FWD-IN"
3536 chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
3637 chainRTPRE = "NETBIRD-RT-PRE"
3738 chainRTRDR = "NETBIRD-RT-RDR"
39+ chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
3840 routingFinalForwardJump = "ACCEPT"
3941 routingFinalNatJump = "MASQUERADE"
4042
4143 jumpManglePre = "jump-mangle-pre"
4244 jumpNatPre = "jump-nat-pre"
4345 jumpNatPost = "jump-nat-post"
46+ jumpMSSClamp = "jump-mss-clamp"
4447 markManglePre = "mark-mangle-pre"
4548 markManglePost = "mark-mangle-post"
4649 matchSet = "--match-set"
4750
4851 dnatSuffix = "_dnat"
4952 snatSuffix = "_snat"
5053 fwdSuffix = "_fwd"
54+
55+ // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
56+ ipTCPHeaderMinSize = 40
5157)
5258
5359type ruleInfo struct {
@@ -77,16 +83,18 @@ type router struct {
7783 ipsetCounter * ipsetCounter
7884 wgIface iFaceMapper
7985 legacyManagement bool
86+ mtu uint16
8087
8188 stateManager * statemanager.Manager
8289 ipFwdState * ipfwdstate.IPForwardingState
8390}
8491
85- func newRouter (iptablesClient * iptables.IPTables , wgIface iFaceMapper ) (* router , error ) {
92+ func newRouter (iptablesClient * iptables.IPTables , wgIface iFaceMapper , mtu uint16 ) (* router , error ) {
8693 r := & router {
8794 iptablesClient : iptablesClient ,
8895 rules : make (map [string ][]string ),
8996 wgIface : wgIface ,
97+ mtu : mtu ,
9098 ipFwdState : ipfwdstate .NewIPForwardingState (),
9199 }
92100
@@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
392400 {chainRTPRE , tableMangle },
393401 {chainRTNAT , tableNat },
394402 {chainRTRDR , tableNat },
403+ {chainRTMSSCLAMP , tableMangle },
395404 } {
396405 ok , err := r .iptablesClient .ChainExists (chainInfo .table , chainInfo .chain )
397406 if err != nil {
@@ -416,6 +425,7 @@ func (r *router) createContainers() error {
416425 {chainRTPRE , tableMangle },
417426 {chainRTNAT , tableNat },
418427 {chainRTRDR , tableNat },
428+ {chainRTMSSCLAMP , tableMangle },
419429 } {
420430 if err := r .iptablesClient .NewChain (chainInfo .table , chainInfo .chain ); err != nil {
421431 return fmt .Errorf ("create chain %s in table %s: %w" , chainInfo .chain , chainInfo .table , err )
@@ -438,6 +448,10 @@ func (r *router) createContainers() error {
438448 return fmt .Errorf ("add jump rules: %w" , err )
439449 }
440450
451+ if err := r .addMSSClampingRules (); err != nil {
452+ log .Errorf ("failed to add MSS clamping rules: %s" , err )
453+ }
454+
441455 return nil
442456}
443457
@@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error {
518532 return nil
519533}
520534
535+ // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
536+ // TODO: Add IPv6 support
537+ func (r * router ) addMSSClampingRules () error {
538+ mss := r .mtu - ipTCPHeaderMinSize
539+
540+ // Add jump rule from FORWARD chain in mangle table to our custom chain
541+ jumpRule := []string {
542+ "-j" , chainRTMSSCLAMP ,
543+ }
544+ if err := r .iptablesClient .Insert (tableMangle , chainFORWARD , 1 , jumpRule ... ); err != nil {
545+ return fmt .Errorf ("add jump to MSS clamp chain: %w" , err )
546+ }
547+ r .rules [jumpMSSClamp ] = jumpRule
548+
549+ ruleOut := []string {
550+ "-o" , r .wgIface .Name (),
551+ "-p" , "tcp" ,
552+ "--tcp-flags" , "SYN,RST" , "SYN" ,
553+ "-j" , "TCPMSS" ,
554+ "--set-mss" , fmt .Sprintf ("%d" , mss ),
555+ }
556+ if err := r .iptablesClient .Append (tableMangle , chainRTMSSCLAMP , ruleOut ... ); err != nil {
557+ return fmt .Errorf ("add outbound MSS clamp rule: %w" , err )
558+ }
559+ r .rules ["mss-clamp-out" ] = ruleOut
560+
561+ return nil
562+ }
563+
521564func (r * router ) insertEstablishedRule (chain string ) error {
522565 establishedRule := getConntrackEstablished ()
523566
@@ -558,7 +601,7 @@ func (r *router) addJumpRules() error {
558601}
559602
560603func (r * router ) cleanJumpRules () error {
561- for _ , ruleKey := range []string {jumpNatPost , jumpManglePre , jumpNatPre } {
604+ for _ , ruleKey := range []string {jumpNatPost , jumpManglePre , jumpNatPre , jumpMSSClamp } {
562605 if rule , exists := r .rules [ruleKey ]; exists {
563606 var table , chain string
564607 switch ruleKey {
@@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error {
571614 case jumpNatPre :
572615 table = tableNat
573616 chain = chainPREROUTING
617+ case jumpMSSClamp :
618+ table = tableMangle
619+ chain = chainFORWARD
574620 default :
575621 return fmt .Errorf ("unknown jump rule: %s" , ruleKey )
576622 }
0 commit comments