Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ability to use iptables-restore #124

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions iptables/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const (

type IPTables struct {
path string
rpath string
proto Protocol
hasCheck bool
hasWait bool
Expand Down Expand Up @@ -155,6 +156,12 @@ func New(opts ...option) (*IPTables, error) {
}
ipt.path = path

rpath, err := exec.LookPath(getIptablesRestoreCommand(ipt.proto))
if err != nil {
return nil, err
}
ipt.rpath = rpath

vstring, err := getIptablesVersionString(path)
if err != nil {
return nil, fmt.Errorf("could not get iptables version: %v", err)
Expand Down Expand Up @@ -233,6 +240,23 @@ func (ipt *IPTables) InsertUnique(table, chain string, pos int, rulespec ...stri
return nil
}

// Restore replaces specified chains and rules in a specific table
// rulesMap is keyed by chain name, and holds slices of rulespecs
// Only chains specified in the map will be flushed and replaced. Other chains will not be affected.
func (ipt *IPTables) Restore(table string, rulesMap map[string][][]string) error {
restoreRules := "*" + table
for chain, rules := range rulesMap {
restoreRules += "\n" + fmt.Sprintf(":%s - [0:0]", strings.ToUpper(chain))
for _, rule := range rules {
restoreRules += "\n" + fmt.Sprintf("-I %s %s", chain, strings.Join(rule, " "))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I assumed Insert for the rules...but maybe it would be more appropriate to use Append. Another option is make it part of the rule, or make it a parameter to the function. I dont have any preference.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would personally expect Append but perhaps it could be documented that this is the behavior

}
}
restoreRules += "\nCOMMIT\n"
cmd := []string{"-n"}

return ipt.runRestore(cmd, restoreRules)
}

// Append appends rulespec to specified table/chain
func (ipt *IPTables) Append(table, chain string, rulespec ...string) error {
cmd := append([]string{"-t", table, "-A", chain}, rulespec...)
Expand Down Expand Up @@ -554,6 +578,57 @@ func (ipt *IPTables) run(args ...string) error {
return ipt.runWithOutput(args, nil)
}

// runWithOutput runs an iptables command with the given arguments,
// writing any stdout output to the given writer
func (ipt *IPTables) runRestore(args []string, input string) error {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if the maintainers are planning to merge this PR but it'd be nice to make this function public or maybe create a wrapper RestoreRaw(args []string, input string)

args = append([]string{ipt.rpath}, args...)
if ipt.hasWait {
args = append(args, "--wait")
if ipt.timeout != 0 && ipt.waitSupportSecond {
args = append(args, strconv.Itoa(ipt.timeout))
}
} else {
fmu, err := newXtablesFileLock()
if err != nil {
return err
}
ul, err := fmu.tryLock()
if err != nil {
syscall.Close(fmu.fd)
return err
}
defer ul.Unlock()
}

var stderr bytes.Buffer
cmd := exec.Cmd{
Path: ipt.rpath,
Args: args,
Stderr: &stderr,
}

stdin, err := cmd.StdinPipe()
if err != nil {
return err
}

go func() {
defer stdin.Close()
io.WriteString(stdin, input)
}()

if err := cmd.Run(); err != nil {
switch e := err.(type) {
case *exec.ExitError:
return &Error{*e, cmd, stderr.String(), nil}
default:
return err
}
}

return nil
}

// runWithOutput runs an iptables command with the given arguments,
// writing any stdout output to the given writer
func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
Expand Down Expand Up @@ -607,6 +682,15 @@ func getIptablesCommand(proto Protocol) string {
}
}

// getIptablesRestoreCommand returns the correct command for the given protocol, either "iptables" or "ip6tables".
func getIptablesRestoreCommand(proto Protocol) string {
if proto == ProtocolIPv6 {
return "ip6tables-restore"
} else {
return "iptables-restore"
}
}

// Checks if iptables has the "-C" and "--wait" flag
func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, bool) {
return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesWaitSupportSecond(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3)
Expand Down
52 changes: 52 additions & 0 deletions iptables/iptables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,58 @@ func mustTestableIptables() []*IPTables {
return ipts
}

func TestRestore(t *testing.T) {
for i, ipt := range mustTestableIptables() {
t.Run(fmt.Sprint(i), func(t *testing.T) {
runRestoreTests(t, ipt)
})
}
}

func runRestoreTests(t *testing.T, ipt *IPTables) {
t.Logf("testing %s (hasWait=%t, hasCheck=%t)", ipt.path, ipt.hasWait, ipt.hasCheck)
var address1, address2, subnet1, subnet2 string
if ipt.Proto() == ProtocolIPv6 {
address1 = "2001:db8::1"
address2 = "2001:db8::2"
subnet1 = "2001:db8:a::/48"
subnet2 = "2001:db8:b::/48"
} else {
address1 = "203.0.113.1"
address2 = "203.0.113.2"
subnet1 = "192.0.2.0/24"
subnet2 = "198.51.100.0/24"
}

chain := randChain(t)
rule1 := []string{"-d", subnet1, "-p", "tcp", "-m", "tcp", "--dport", "80", "-j", "DNAT", "--to-destination", fmt.Sprintf("%s:80", address1)}
rule2 := []string{"-d", subnet2, "-p", "tcp", "-m", "tcp", "--dport", "80", "-j", "DNAT", "--to-destination", fmt.Sprintf("%s:80", address2)}

x := map[string][][]string{
chain: {rule1, rule2},
}
err := ipt.Restore("nat", x)
if err != nil {
t.Fatalf("Restore failed: %v", err)
}

rules, err := ipt.List("nat", chain)
if err != nil {
t.Fatalf("List failed: %v", err)
}

expected := []string{
"-N " + chain,
"-A " + chain + " -d " + subnet2 + " -p tcp -m tcp --dport 80 -j DNAT --to-destination " + fmt.Sprintf("%s:80", address2),
"-A " + chain + " -d " + subnet1 + " -p tcp -m tcp --dport 80 -j DNAT --to-destination " + fmt.Sprintf("%s:80", address1),
}

if !reflect.DeepEqual(rules, expected) {
t.Fatalf("List mismatch: \ngot %#v \nneed %#v", rules, expected)
}

}

func TestChain(t *testing.T) {
for i, ipt := range mustTestableIptables() {
t.Run(fmt.Sprint(i), func(t *testing.T) {
Expand Down