Skip to content

Commit

Permalink
Adds ability to use iptables-restore
Browse files Browse the repository at this point in the history
This improves efficiency when adding a lot of rules to a table. Rather
than calling insert or append for each rule, we can execute one iptables
operation to replace them all.

Signed-off-by: Tim Rozet <[email protected]>
  • Loading branch information
trozet committed Mar 28, 2024
1 parent 65c67c9 commit 077e672
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 0 deletions.
84 changes: 84 additions & 0 deletions iptables/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const (

type IPTables struct {
path string
rpath string
proto Protocol
hasCheck bool
hasWait bool
Expand Down Expand Up @@ -155,6 +156,12 @@ func New(opts ...option) (*IPTables, error) {
}
ipt.path = path

rpath, err := exec.LookPath(getIptablesRestoreCommand(ipt.proto))

This comment has been minimized.

Copy link
@tssurya

tssurya Apr 13, 2024

can the rpath be different from path ?

This comment has been minimized.

Copy link
@trozet

trozet Apr 15, 2024

Author

probably not, but I think its better not to assume

if err != nil {
return nil, err
}
ipt.rpath = rpath

vstring, err := getIptablesVersionString(path)
if err != nil {
return nil, fmt.Errorf("could not get iptables version: %v", err)
Expand Down Expand Up @@ -233,6 +240,23 @@ func (ipt *IPTables) InsertUnique(table, chain string, pos int, rulespec ...stri
return nil
}

// Restore replaces specified chains and rules in a specific table
// rulesMap is keyed by chain name, and holds slices of rulespecs
// Only chains specified in the map will be flushed and replaced. Other chains will not be affected.
func (ipt *IPTables) Restore(table string, rulesMap map[string][][]string) error {
restoreRules := "*" + table
for chain, rules := range rulesMap {
restoreRules += "\n" + fmt.Sprintf(":%s - [0:0]", strings.ToUpper(chain))
for _, rule := range rules {
restoreRules += "\n" + fmt.Sprintf("-I %s %s", chain, strings.Join(rule, " "))
}
}
restoreRules += "\nCOMMIT\n"
cmd := []string{"-n"}

return ipt.runRestore(cmd, restoreRules)
}

// Append appends rulespec to specified table/chain
func (ipt *IPTables) Append(table, chain string, rulespec ...string) error {
cmd := append([]string{"-t", table, "-A", chain}, rulespec...)
Expand Down Expand Up @@ -554,6 +578,57 @@ func (ipt *IPTables) run(args ...string) error {
return ipt.runWithOutput(args, nil)
}

// runWithOutput runs an iptables command with the given arguments,
// writing any stdout output to the given writer
func (ipt *IPTables) runRestore(args []string, input string) error {
args = append([]string{ipt.rpath}, args...)
if ipt.hasWait {
args = append(args, "--wait")
if ipt.timeout != 0 && ipt.waitSupportSecond {
args = append(args, strconv.Itoa(ipt.timeout))
}
} else {
fmu, err := newXtablesFileLock()
if err != nil {
return err
}
ul, err := fmu.tryLock()
if err != nil {
syscall.Close(fmu.fd)
return err
}
defer ul.Unlock()
}

var stderr bytes.Buffer
cmd := exec.Cmd{
Path: ipt.rpath,
Args: args,
Stderr: &stderr,
}

stdin, err := cmd.StdinPipe()
if err != nil {
return err
}

go func() {
defer stdin.Close()
io.WriteString(stdin, input)
}()

if err := cmd.Run(); err != nil {
switch e := err.(type) {
case *exec.ExitError:
return &Error{*e, cmd, stderr.String(), nil}
default:
return err
}
}

return nil
}

// runWithOutput runs an iptables command with the given arguments,
// writing any stdout output to the given writer
func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
Expand Down Expand Up @@ -607,6 +682,15 @@ func getIptablesCommand(proto Protocol) string {
}
}

// getIptablesRestoreCommand returns the correct command for the given protocol, either "iptables" or "ip6tables".
func getIptablesRestoreCommand(proto Protocol) string {
if proto == ProtocolIPv6 {
return "ip6tables-restore"
} else {
return "iptables-restore"
}
}

// Checks if iptables has the "-C" and "--wait" flag
func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, bool) {
return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesWaitSupportSecond(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3)
Expand Down
52 changes: 52 additions & 0 deletions iptables/iptables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,58 @@ func mustTestableIptables() []*IPTables {
return ipts
}

func TestRestore(t *testing.T) {
for i, ipt := range mustTestableIptables() {
t.Run(fmt.Sprint(i), func(t *testing.T) {
runRestoreTests(t, ipt)
})
}
}

func runRestoreTests(t *testing.T, ipt *IPTables) {
t.Logf("testing %s (hasWait=%t, hasCheck=%t)", ipt.path, ipt.hasWait, ipt.hasCheck)
var address1, address2, subnet1, subnet2 string
if ipt.Proto() == ProtocolIPv6 {
address1 = "2001:db8::1"
address2 = "2001:db8::2"
subnet1 = "2001:db8:a::/48"
subnet2 = "2001:db8:b::/48"
} else {
address1 = "203.0.113.1"
address2 = "203.0.113.2"
subnet1 = "192.0.2.0/24"
subnet2 = "198.51.100.0/24"
}

chain := randChain(t)
rule1 := []string{"-d", subnet1, "-p", "tcp", "-m", "tcp", "--dport", "80", "-j", "DNAT", "--to-destination", fmt.Sprintf("%s:80", address1)}
rule2 := []string{"-d", subnet2, "-p", "tcp", "-m", "tcp", "--dport", "80", "-j", "DNAT", "--to-destination", fmt.Sprintf("%s:80", address2)}

x := map[string][][]string{
chain: {rule1, rule2},
}
err := ipt.Restore("nat", x)
if err != nil {
t.Fatalf("Restore failed: %v", err)
}

rules, err := ipt.List("nat", chain)
if err != nil {
t.Fatalf("List failed: %v", err)
}

expected := []string{
"-N " + chain,
"-A " + chain + " -d " + subnet2 + " -p tcp -m tcp --dport 80 -j DNAT --to-destination " + fmt.Sprintf("%s:80", address2),
"-A " + chain + " -d " + subnet1 + " -p tcp -m tcp --dport 80 -j DNAT --to-destination " + fmt.Sprintf("%s:80", address1),
}

if !reflect.DeepEqual(rules, expected) {
t.Fatalf("List mismatch: \ngot %#v \nneed %#v", rules, expected)
}

}

func TestChain(t *testing.T) {
for i, ipt := range mustTestableIptables() {
t.Run(fmt.Sprint(i), func(t *testing.T) {
Expand Down

0 comments on commit 077e672

Please sign in to comment.