Skip to content

Commit

Permalink
Iptables mode selection
Browse files Browse the repository at this point in the history
  • Loading branch information
cheina97 committed Aug 24, 2023
1 parent b9dff5a commit c70a91b
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 63 deletions.
59 changes: 44 additions & 15 deletions iptables/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ const (
ProtocolIPv6
)

// Mode to differentiate between legacy and nf_tables
type ModeType string

const (
// ModeTypeAuto is the default mode, which uses the system default
ModeTypeAuto ModeType = "auto"
// ModeTypeLegacy forces the use of the legacy iptables mode
ModeTypeLegacy ModeType = "legacy"
// ModeTypeNFTables forces the use of the nf_tables iptables mode
ModeTypeNFTables ModeType = "nf_tables"
)

type IPTables struct {
path string
proto Protocol
Expand All @@ -74,8 +86,8 @@ type IPTables struct {
v1 int
v2 int
v3 int
mode string // the underlying iptables operating mode, e.g. nf_tables
timeout int // time to wait for the iptables lock, default waits forever
mode ModeType // the underlying iptables operating mode, e.g. nf_tables
timeout int // time to wait for the iptables lock, default waits forever
}

// Stat represents a structured statistic entry.
Expand Down Expand Up @@ -106,6 +118,12 @@ func Timeout(timeout int) option {
}
}

func Mode(mode ModeType) option {
return func(ipt *IPTables) {
ipt.mode = mode
}
}

// New creates a new IPTables configured with the options passed as parameter.
// For backwards compatibility, by default always uses IPv4 and timeout 0.
// i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing
Expand All @@ -116,14 +134,15 @@ func New(opts ...option) (*IPTables, error) {

ipt := &IPTables{
proto: ProtocolIPv4,
mode: ModeTypeAuto,
timeout: 0,
}

for _, opt := range opts {
opt(ipt)
}

path, err := exec.LookPath(getIptablesCommand(ipt.proto))
path, err := exec.LookPath(getIptablesCommand(ipt.proto, ipt.mode))
if err != nil {
return nil, err
}
Expand All @@ -133,14 +152,13 @@ func New(opts ...option) (*IPTables, error) {
if err != nil {
return nil, fmt.Errorf("could not get iptables version: %v", err)
}
v1, v2, v3, mode, err := extractIptablesVersion(vstring)
v1, v2, v3, _, err := extractIptablesVersion(vstring)
if err != nil {
return nil, fmt.Errorf("failed to extract iptables version from [%s]: %v", vstring, err)
}
ipt.v1 = v1
ipt.v2 = v2
ipt.v3 = v3
ipt.mode = mode

checkPresent, waitPresent, waitSupportSecond, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)
ipt.hasCheck = checkPresent
Expand Down Expand Up @@ -518,8 +536,8 @@ func (ipt *IPTables) HasRandomFully() bool {
}

// Return version components of the underlying iptables command
func (ipt *IPTables) GetIptablesVersion() (int, int, int) {
return ipt.v1, ipt.v2, ipt.v3
func (ipt *IPTables) GetIptablesVersion() (int, int, int, ModeType) {
return ipt.v1, ipt.v2, ipt.v3, ipt.mode
}

// run runs an iptables command with the given arguments, ignoring
Expand Down Expand Up @@ -573,12 +591,23 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
}

// getIptablesCommand returns the correct command for the given protocol, either "iptables" or "ip6tables".
func getIptablesCommand(proto Protocol) string {
if proto == ProtocolIPv6 {
return "ip6tables"
} else {
return "iptables"
func getIptablesCommand(proto Protocol, mode ModeType) string {
var cmd string
switch proto {
case ProtocolIPv4:
cmd = "iptables"
case ProtocolIPv6:
cmd = "ip6tables"
}
// Append a suffix to the command to get the correct binary,
// If the mode is auto (default), the suffix is not applied and the system default is used.
switch mode {
case ModeTypeNFTables:
cmd = fmt.Sprintf("%s-nft", cmd)
case ModeTypeLegacy:
cmd = fmt.Sprintf("%s-legacy", cmd)
}
return cmd
}

// Checks if iptables has the "-C" and "--wait" flag
Expand All @@ -589,7 +618,7 @@ func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, bool)
// getIptablesVersion returns the first three components of the iptables version
// and the operating mode (e.g. nf_tables or legacy)
// e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil)
func extractIptablesVersion(str string) (int, int, int, string, error) {
func extractIptablesVersion(str string) (int, int, int, ModeType, error) {
versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`)
result := versionMatcher.FindStringSubmatch(str)
if result == nil {
Expand All @@ -611,9 +640,9 @@ func extractIptablesVersion(str string) (int, int, int, string, error) {
return 0, 0, 0, "", err
}

mode := "legacy"
mode := ModeTypeLegacy
if result[4] != "" {
mode = result[4]
mode = ModeType(result[4])
}
return v1, v2, v3, mode, nil
}
Expand Down
138 changes: 90 additions & 48 deletions iptables/iptables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,23 @@ import (
"testing"
)

var (
protos = []Protocol{ProtocolIPv4, ProtocolIPv6}
modes = []ModeType{ModeTypeAuto, ModeTypeLegacy, ModeTypeNFTables}
)

// getProtoName returns the name of the protocol, for use in test names.
func getProtoName(proto Protocol) string {
switch proto {
case ProtocolIPv4:
return "IPv4"
case ProtocolIPv6:
return "IPv6"
default:
panic("unknown protocol")
}
}

func TestProto(t *testing.T) {
ipt, err := New()
if err != nil {
Expand All @@ -34,40 +51,72 @@ func TestProto(t *testing.T) {
t.Fatalf("Expected default protocol IPv4, got %v", ipt.Proto())
}

ip4t, err := NewWithProtocol(ProtocolIPv4)
if err != nil {
t.Fatalf("NewWithProtocol(ProtocolIPv4) failed: %v", err)
}
if ip4t.Proto() != ProtocolIPv4 {
t.Fatalf("Expected protocol IPv4, got %v", ip4t.Proto())
for _, proto := range protos {
protoName := getProtoName(proto)
ipt, err := New(IPFamily(proto))
if err != nil {
t.Fatalf("NewWithProtocol(%s) failed: %v", protoName, err)
}
if ipt.Proto() != proto {
t.Fatalf("Expected protocol %s, got %v", protoName, ipt.Proto())
}
if ipt.mode != ModeTypeAuto {
t.Fatalf("Expected mode auto, got %v", ipt.mode)
}
}

ip6t, err := NewWithProtocol(ProtocolIPv6)
if err != nil {
t.Fatalf("NewWithProtocol(ProtocolIPv6) failed: %v", err)
}
if ip6t.Proto() != ProtocolIPv6 {
t.Fatalf("Expected protocol IPv6, got %v", ip6t.Proto())
for _, proto := range protos {
for _, mode := range modes {
protoName := getProtoName(proto)
ipt, err := New(Mode(mode), IPFamily(proto))
if err != nil {
t.Fatalf("New(Mode(%v), IPFamily(%v)) failed: %v", mode, protoName, err)
}
if ipt.Proto() != proto {
t.Fatalf("Expected protocol %v, got %v", protoName, ipt.Proto())
}
if ipt.mode != mode {
t.Fatalf("Expected mode %v, got %v", mode, ipt.mode)
}
}
}
}

func TestTimeout(t *testing.T) {
ipt, err := New()
if err != nil {
t.Fatalf("New failed: %v", err)
}
if ipt.timeout != 0 {
t.Fatalf("Expected timeout 0 (wait forever), got %v", ipt.timeout)
}
for _, proto := range protos {
for _, mode := range modes {
ipt, err := New(IPFamily(proto), Mode(mode))
if err != nil {
t.Fatalf("New failed: %v", err)
}
if ipt.timeout != 0 {
t.Fatalf("Expected timeout 0 (wait forever), got %v", ipt.timeout)
}

ipt2, err := New(Timeout(5))
if err != nil {
t.Fatalf("New failed: %v", err)
}
if ipt2.timeout != 5 {
t.Fatalf("Expected timeout 5, got %v", ipt.timeout)
ipt2, err := New(Timeout(5))
if err != nil {
t.Fatalf("New failed: %v", err)
}
if ipt2.timeout != 5 {
t.Fatalf("Expected timeout 5, got %v", ipt.timeout)
}
}
}
}

func TestGetIptablesVersionMode(t *testing.T) {
for _, proto := range protos {
for _, mode := range modes {
ipt, err := New(IPFamily(proto), Mode(mode))
if err != nil {
t.Fatalf("New failed: %v", err)
}
_, _, _, getmode := ipt.GetIptablesVersion()
if getmode != mode {
t.Fatalf("Expected mode %v, got %v", mode, mode)
}
}
}
}

func randChain(t *testing.T) string {
Expand All @@ -92,27 +141,20 @@ func contains(list []string, value string) bool {
// features enabled & disabled, to test compatibility.
// We used to test noWait as well, but that was removed as of iptables v1.6.0
func mustTestableIptables() []*IPTables {
ipt, err := New()
if err != nil {
panic(fmt.Sprintf("New failed: %v", err))
}
ip6t, err := NewWithProtocol(ProtocolIPv6)
if err != nil {
panic(fmt.Sprintf("NewWithProtocol(ProtocolIPv6) failed: %v", err))
}
ipts := []*IPTables{ipt, ip6t}

// ensure we check one variant without built-in checking
if ipt.hasCheck {
i := *ipt
i.hasCheck = false
ipts = append(ipts, &i)

i6 := *ip6t
i6.hasCheck = false
ipts = append(ipts, &i6)
} else {
panic("iptables on this machine is too old -- missing -C")
ipts := []*IPTables{}
for _, proto := range protos {
for _, mode := range modes {
ipt, err := New(IPFamily(proto), Mode(mode))
if err != nil {
panic(fmt.Sprintf("New(IPFamily(%v), Mode(%v)) failed: %v", proto, mode, err))
}
if ipt.hasCheck {
ipt.hasCheck = false
ipts = append(ipts, ipt)
} else {
panic("iptables on this machine is too old -- missing -C")
}
}
}
return ipts
}
Expand Down Expand Up @@ -251,7 +293,7 @@ func TestRules(t *testing.T) {
}

func runRulesTests(t *testing.T, ipt *IPTables) {
t.Logf("testing %s (hasWait=%t, hasCheck=%t)", getIptablesCommand(ipt.Proto()), ipt.hasWait, ipt.hasCheck)
t.Logf("testing %s (hasWait=%t, hasCheck=%t)", getIptablesCommand(ipt.Proto(), ModeTypeAuto), ipt.hasWait, ipt.hasCheck)

var address1, address2, subnet1, subnet2 string
if ipt.Proto() == ProtocolIPv6 {
Expand Down Expand Up @@ -689,7 +731,7 @@ func TestExtractIptablesVersion(t *testing.T) {
t.Fatalf("unexpected err %s", err)
}

if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != tt.mode {
if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != ModeType(tt.mode) {
t.Fatalf("expected %d %d %d %s, got %d %d %d %s",
tt.v1, tt.v2, tt.v3, tt.mode,
v1, v2, v3, mode)
Expand Down

0 comments on commit c70a91b

Please sign in to comment.