diff --git a/iptables/iptables.go b/iptables/iptables.go index 6c5bbd7..0440134 100644 --- a/iptables/iptables.go +++ b/iptables/iptables.go @@ -72,6 +72,7 @@ const ( type IPTables struct { path string + rpath string proto Protocol hasCheck bool hasWait bool @@ -155,6 +156,12 @@ func New(opts ...option) (*IPTables, error) { } ipt.path = path + rpath, err := exec.LookPath(getIptablesRestoreCommand(ipt.proto)) + if err != nil { + return nil, err + } + ipt.rpath = rpath + vstring, err := getIptablesVersionString(path) if err != nil { return nil, fmt.Errorf("could not get iptables version: %v", err) @@ -233,6 +240,23 @@ func (ipt *IPTables) InsertUnique(table, chain string, pos int, rulespec ...stri return nil } +// Restore replaces specified chains and rules in a specific table +// rulesMap is keyed by chain name, and holds slices of rulespecs +// Only chains specified in the map will be flushed and replaced. Other chains will not be affected. +func (ipt *IPTables) Restore(table string, rulesMap map[string][][]string) error { + restoreRules := "*" + table + for chain, rules := range rulesMap { + restoreRules += "\n" + fmt.Sprintf(":%s - [0:0]", strings.ToUpper(chain)) + for _, rule := range rules { + restoreRules += "\n" + fmt.Sprintf("-I %s %s", chain, strings.Join(rule, " ")) + } + } + restoreRules += "\nCOMMIT\n" + cmd := []string{"-n"} + + return ipt.runRestore(cmd, restoreRules) +} + // Append appends rulespec to specified table/chain func (ipt *IPTables) Append(table, chain string, rulespec ...string) error { cmd := append([]string{"-t", table, "-A", chain}, rulespec...) @@ -554,6 +578,57 @@ func (ipt *IPTables) run(args ...string) error { return ipt.runWithOutput(args, nil) } +// runWithOutput runs an iptables command with the given arguments, +// writing any stdout output to the given writer +func (ipt *IPTables) runRestore(args []string, input string) error { + args = append([]string{ipt.rpath}, args...) + if ipt.hasWait { + args = append(args, "--wait") + if ipt.timeout != 0 && ipt.waitSupportSecond { + args = append(args, strconv.Itoa(ipt.timeout)) + } + } else { + fmu, err := newXtablesFileLock() + if err != nil { + return err + } + ul, err := fmu.tryLock() + if err != nil { + syscall.Close(fmu.fd) + return err + } + defer ul.Unlock() + } + + var stderr bytes.Buffer + cmd := exec.Cmd{ + Path: ipt.rpath, + Args: args, + Stderr: &stderr, + } + + stdin, err := cmd.StdinPipe() + if err != nil { + return err + } + + go func() { + defer stdin.Close() + io.WriteString(stdin, input) + }() + + if err := cmd.Run(); err != nil { + switch e := err.(type) { + case *exec.ExitError: + return &Error{*e, cmd, stderr.String(), nil} + default: + return err + } + } + + return nil +} + // runWithOutput runs an iptables command with the given arguments, // writing any stdout output to the given writer func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error { @@ -607,6 +682,15 @@ func getIptablesCommand(proto Protocol) string { } } +// getIptablesRestoreCommand returns the correct command for the given protocol, either "iptables" or "ip6tables". +func getIptablesRestoreCommand(proto Protocol) string { + if proto == ProtocolIPv6 { + return "ip6tables-restore" + } else { + return "iptables-restore" + } +} + // Checks if iptables has the "-C" and "--wait" flag func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, bool) { return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesWaitSupportSecond(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3) diff --git a/iptables/iptables_test.go b/iptables/iptables_test.go index f341e2c..9caba5d 100644 --- a/iptables/iptables_test.go +++ b/iptables/iptables_test.go @@ -165,6 +165,58 @@ func mustTestableIptables() []*IPTables { return ipts } +func TestRestore(t *testing.T) { + for i, ipt := range mustTestableIptables() { + t.Run(fmt.Sprint(i), func(t *testing.T) { + runRestoreTests(t, ipt) + }) + } +} + +func runRestoreTests(t *testing.T, ipt *IPTables) { + t.Logf("testing %s (hasWait=%t, hasCheck=%t)", ipt.path, ipt.hasWait, ipt.hasCheck) + var address1, address2, subnet1, subnet2 string + if ipt.Proto() == ProtocolIPv6 { + address1 = "2001:db8::1" + address2 = "2001:db8::2" + subnet1 = "2001:db8:a::/48" + subnet2 = "2001:db8:b::/48" + } else { + address1 = "203.0.113.1" + address2 = "203.0.113.2" + subnet1 = "192.0.2.0/24" + subnet2 = "198.51.100.0/24" + } + + chain := randChain(t) + rule1 := []string{"-d", subnet1, "-p", "tcp", "-m", "tcp", "--dport", "80", "-j", "DNAT", "--to-destination", fmt.Sprintf("%s:80", address1)} + rule2 := []string{"-d", subnet2, "-p", "tcp", "-m", "tcp", "--dport", "80", "-j", "DNAT", "--to-destination", fmt.Sprintf("%s:80", address2)} + + x := map[string][][]string{ + chain: {rule1, rule2}, + } + err := ipt.Restore("nat", x) + if err != nil { + t.Fatalf("Restore failed: %v", err) + } + + rules, err := ipt.List("nat", chain) + if err != nil { + t.Fatalf("List failed: %v", err) + } + + expected := []string{ + "-N " + chain, + "-A " + chain + " -d " + subnet2 + " -p tcp -m tcp --dport 80 -j DNAT --to-destination " + fmt.Sprintf("%s:80", address2), + "-A " + chain + " -d " + subnet1 + " -p tcp -m tcp --dport 80 -j DNAT --to-destination " + fmt.Sprintf("%s:80", address1), + } + + if !reflect.DeepEqual(rules, expected) { + t.Fatalf("List mismatch: \ngot %#v \nneed %#v", rules, expected) + } + +} + func TestChain(t *testing.T) { for i, ipt := range mustTestableIptables() { t.Run(fmt.Sprint(i), func(t *testing.T) {