From 204af48a9c25dc37dd51141376dbe6e77b515f65 Mon Sep 17 00:00:00 2001 From: Luke Young Date: Sat, 18 Apr 2020 11:50:50 -0700 Subject: [PATCH] Migrate to using golang.org/x/sys/unix over syscall --- audit.go | 4 ++-- audit_test.go | 10 +++++----- client.go | 45 ++++++++++++++++++++++++++------------------- client_test.go | 29 +++++++++++++++-------------- go.mod | 1 + marshaller.go | 3 +-- marshaller_test.go | 35 +++++++++++++++++------------------ parser.go | 5 ++--- parser_test.go | 8 ++++---- 9 files changed, 73 insertions(+), 67 deletions(-) diff --git a/audit.go b/audit.go index d34c84f..447ef35 100644 --- a/audit.go +++ b/audit.go @@ -14,9 +14,9 @@ import ( "regexp" "strconv" "strings" - "syscall" "github.com/spf13/viper" + "golang.org/x/sys/unix" "gopkg.in/Graylog2/go-gelf.v2/gelf" ) @@ -249,7 +249,7 @@ func handleLogRotation(config *viper.Viper, writer *AuditWriter) { // Re-open our log file. This is triggered by a USR1 signal and is meant to be used upon log rotation sigc := make(chan os.Signal, 1) - signal.Notify(sigc, syscall.SIGUSR1) + signal.Notify(sigc, unix.SIGUSR1) for range sigc { newWriter, err := createFileOutput(config) diff --git a/audit_test.go b/audit_test.go index fc39511..bdecd37 100644 --- a/audit_test.go +++ b/audit_test.go @@ -10,12 +10,12 @@ import ( "os/user" "path" "strconv" - "syscall" "testing" "time" "github.com/spf13/viper" "github.com/stretchr/testify/assert" + "golang.org/x/sys/unix" "gopkg.in/Graylog2/go-gelf.v2/gelf" ) @@ -413,7 +413,7 @@ func Test_createOutput(t *testing.T) { os.Rename(path.Join(os.TempDir(), "go-audit.test.log"), path.Join(os.TempDir(), "go-audit.test.log.rotated")) _, err = os.Stat(path.Join(os.TempDir(), "go-audit.test.log")) assert.True(t, os.IsNotExist(err)) - syscall.Kill(syscall.Getpid(), syscall.SIGUSR1) + unix.Kill(unix.Getpid(), unix.SIGUSR1) time.Sleep(100 * time.Millisecond) _, err = os.Stat(path.Join(os.TempDir(), "go-audit.test.log")) assert.Nil(t, err) @@ -565,15 +565,15 @@ func Benchmark_MultiPacketMessage(b *testing.B) { for i := 0; i < b.N; i++ { for n := 0; n < len(data); n++ { nlen := len(data[n]) - msg := &syscall.NetlinkMessage{ - Header: syscall.NlMsghdr{ + msg := &NetlinkMessage{ + Header: NetlinkPacket{ Len: Endianness.Uint32(data[n][0:4]), Type: Endianness.Uint16(data[n][4:6]), Flags: Endianness.Uint16(data[n][6:8]), Seq: Endianness.Uint32(data[n][8:12]), Pid: Endianness.Uint32(data[n][12:16]), }, - Data: data[n][syscall.SizeofNlMsghdr:nlen], + Data: data[n][unix.SizeofNlMsghdr:nlen], } marshaller.Consume(msg) } diff --git a/client.go b/client.go index 7538cb1..559d569 100644 --- a/client.go +++ b/client.go @@ -4,10 +4,11 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "sync/atomic" - "syscall" "time" - "fmt" + + "golang.org/x/sys/unix" ) // Endianness is an alias for what we assume is the current machine endianness @@ -33,42 +34,48 @@ type AuditStatusPayload struct { } // NetlinkPacket is an alias to give the header a similar name here -type NetlinkPacket syscall.NlMsghdr +type NetlinkPacket unix.NlMsghdr + +// NetlinkMessage is copied from syscall.NetlinkMessage as x/sys/unix does not have it +type NetlinkMessage struct { + Header NetlinkPacket + Data []byte +} type NetlinkClient struct { fd int - address syscall.Sockaddr + address unix.Sockaddr seq uint32 buf []byte } // NewNetlinkClient creates a new NetLinkClient and optionally tries to modify the netlink recv buffer func NewNetlinkClient(recvSize int) (*NetlinkClient, error) { - fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, syscall.NETLINK_AUDIT) + fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_AUDIT) if err != nil { return nil, fmt.Errorf("Could not create a socket: %s", err) } n := &NetlinkClient{ fd: fd, - address: &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK, Groups: 0, Pid: 0}, + address: &unix.SockaddrNetlink{Family: unix.AF_NETLINK, Groups: 0, Pid: 0}, buf: make([]byte, MAX_AUDIT_MESSAGE_LENGTH), } - if err = syscall.Bind(fd, n.address); err != nil { - syscall.Close(fd) + if err = unix.Bind(fd, n.address); err != nil { + unix.Close(fd) return nil, fmt.Errorf("Could not bind to netlink socket: %s", err) } // Set the buffer size if we were asked if recvSize > 0 { - if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, recvSize); err != nil { + if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, recvSize); err != nil { el.Println("Failed to set receive buffer size") } } // Print the current receive buffer size - if v, err := syscall.GetsockoptInt(n.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF); err == nil { + if v, err := unix.GetsockoptInt(n.fd, unix.SOL_SOCKET, unix.SO_RCVBUF); err == nil { l.Println("Socket receive buffer size:", v) } @@ -102,7 +109,7 @@ func (n *NetlinkClient) Send(np *NetlinkPacket, a *AuditStatusPayload) error { } } - if err := syscall.Sendto(n.fd, buf.Bytes(), 0, n.address); err != nil { + if err := unix.Sendto(n.fd, buf.Bytes(), 0, n.address); err != nil { return err } @@ -110,8 +117,8 @@ func (n *NetlinkClient) Send(np *NetlinkPacket, a *AuditStatusPayload) error { } // Receive will receive a packet from a netlink socket -func (n *NetlinkClient) Receive() (*syscall.NetlinkMessage, error) { - nlen, _, err := syscall.Recvfrom(n.fd, n.buf, 0) +func (n *NetlinkClient) Receive() (*NetlinkMessage, error) { + nlen, _, err := unix.Recvfrom(n.fd, n.buf, 0) if err != nil { return nil, err } @@ -120,15 +127,15 @@ func (n *NetlinkClient) Receive() (*syscall.NetlinkMessage, error) { return nil, errors.New("Got a 0 length packet") } - msg := &syscall.NetlinkMessage{ - Header: syscall.NlMsghdr{ + msg := &NetlinkMessage{ + Header: NetlinkPacket{ Len: Endianness.Uint32(n.buf[0:4]), Type: Endianness.Uint16(n.buf[4:6]), Flags: Endianness.Uint16(n.buf[6:8]), Seq: Endianness.Uint32(n.buf[8:12]), Pid: Endianness.Uint32(n.buf[12:16]), }, - Data: n.buf[syscall.SizeofNlMsghdr:nlen], + Data: n.buf[unix.SizeofNlMsghdr:nlen], } return msg, nil @@ -139,14 +146,14 @@ func (n *NetlinkClient) KeepConnection() { payload := &AuditStatusPayload{ Mask: 4, Enabled: 1, - Pid: uint32(syscall.Getpid()), + Pid: uint32(unix.Getpid()), //TODO: Failure: http://lxr.free-electrons.com/source/include/uapi/linux/audit.h#L338 } packet := &NetlinkPacket{ Type: uint16(1001), - Flags: syscall.NLM_F_REQUEST | syscall.NLM_F_ACK, - Pid: uint32(syscall.Getpid()), + Flags: unix.NLM_F_REQUEST | unix.NLM_F_ACK, + Pid: uint32(unix.Getpid()), } err := n.Send(packet, payload) diff --git a/client_test.go b/client_test.go index 440fd23..e921409 100644 --- a/client_test.go +++ b/client_test.go @@ -3,15 +3,16 @@ package main import ( "bytes" "encoding/binary" - "github.com/stretchr/testify/assert" "os" - "syscall" "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/sys/unix" ) func TestNetlinkClient_KeepConnection(t *testing.T) { n := makeNelinkClient(t) - defer syscall.Close(n.fd) + defer unix.Close(n.fd) n.KeepConnection() msg, err := n.Receive() @@ -31,7 +32,7 @@ func TestNetlinkClient_KeepConnection(t *testing.T) { // Make sure we get errors printed lb, elb := hookLogger() defer resetLogger() - syscall.Close(n.fd) + unix.Close(n.fd) n.KeepConnection() assert.Equal(t, "", lb.String(), "Got some log lines we did not expect") assert.Equal(t, "Error occurred while trying to keep the connection: bad file descriptor\n", elb.String(), "Figured we would have an error") @@ -39,11 +40,11 @@ func TestNetlinkClient_KeepConnection(t *testing.T) { func TestNetlinkClient_SendReceive(t *testing.T) { var err error - var msg *syscall.NetlinkMessage + var msg *NetlinkMessage // Build our client n := makeNelinkClient(t) - defer syscall.Close(n.fd) + defer unix.Close(n.fd) // Make sure we can encode/decode properly payload := &AuditStatusPayload{ @@ -54,7 +55,7 @@ func TestNetlinkClient_SendReceive(t *testing.T) { packet := &NetlinkPacket{ Type: uint16(1001), - Flags: syscall.NLM_F_REQUEST | syscall.NLM_F_ACK, + Flags: unix.NLM_F_REQUEST | unix.NLM_F_ACK, Pid: uint32(1006), } @@ -72,12 +73,12 @@ func TestNetlinkClient_SendReceive(t *testing.T) { assert.Equal(t, uint32(2), msg.Header.Seq, "Header.Seq did not increment") // Make sure 0 length packets result in an error - syscall.Sendto(n.fd, []byte{}, 0, n.address) + unix.Sendto(n.fd, []byte{}, 0, n.address) _, err = n.Receive() assert.Equal(t, "Got a 0 length packet", err.Error(), "Error was incorrect") // Make sure we get errors from sendto back - syscall.Close(n.fd) + unix.Close(n.fd) err = n.Send(packet, payload) assert.Equal(t, "bad file descriptor", err.Error(), "Error was incorrect") @@ -110,19 +111,19 @@ func TestNewNetlinkClient(t *testing.T) { // Helper to make a client listening on a unix socket func makeNelinkClient(t *testing.T) *NetlinkClient { os.Remove("go-audit.test.sock") - fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_RAW, 0) + fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_RAW, 0) if err != nil { t.Fatal("Could not create a socket:", err) } n := &NetlinkClient{ fd: fd, - address: &syscall.SockaddrUnix{Name: "go-audit.test.sock"}, + address: &unix.SockaddrUnix{Name: "go-audit.test.sock"}, buf: make([]byte, MAX_AUDIT_MESSAGE_LENGTH), } - if err = syscall.Bind(fd, n.address); err != nil { - syscall.Close(fd) + if err = unix.Bind(fd, n.address); err != nil { + unix.Close(fd) t.Fatal("Could not bind to netlink socket:", err) } @@ -130,7 +131,7 @@ func makeNelinkClient(t *testing.T) *NetlinkClient { } // Helper to send and then receive a message with the netlink client -func sendReceive(t *testing.T, n *NetlinkClient, packet *NetlinkPacket, payload *AuditStatusPayload) *syscall.NetlinkMessage { +func sendReceive(t *testing.T, n *NetlinkClient, packet *NetlinkPacket, payload *AuditStatusPayload) *NetlinkMessage { err := n.Send(packet, payload) if err != nil { t.Fatal("Failed to send:", err) diff --git a/go.mod b/go.mod index e62798d..ad08a4f 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( github.com/spf13/viper v0.0.0-20170217163817-7538d73b4eb9 github.com/stretchr/testify v1.2.2 golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect + golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect google.golang.org/grpc v1.25.1 // indirect gopkg.in/Graylog2/go-gelf.v2 v2.0.0-20180326133423-4dbb9d721348 diff --git a/marshaller.go b/marshaller.go index 7d02051..8f5fd98 100644 --- a/marshaller.go +++ b/marshaller.go @@ -3,7 +3,6 @@ package main import ( "os" "regexp" - "syscall" "time" ) @@ -64,7 +63,7 @@ func NewAuditMarshaller(w *AuditWriter, eventMin uint16, eventMax uint16, trackM } // Ingests a netlink message and likely prepares it to be logged -func (a *AuditMarshaller) Consume(nlMsg *syscall.NetlinkMessage) { +func (a *AuditMarshaller) Consume(nlMsg *NetlinkMessage) { aMsg := NewAuditMessage(nlMsg) if aMsg.Seq == 0 { diff --git a/marshaller_test.go b/marshaller_test.go index 4a638aa..9892dd9 100644 --- a/marshaller_test.go +++ b/marshaller_test.go @@ -3,7 +3,6 @@ package main import ( "bytes" "errors" - "syscall" "testing" "time" @@ -19,8 +18,8 @@ func TestAuditMarshaller_Consume(t *testing.T) { m := NewAuditMarshaller(NewAuditWriter(w, 1), uint16(1100), uint16(1399), false, false, 0, []AuditFilter{}, nil) // Flush group on 1320 - m.Consume(&syscall.NetlinkMessage{ - Header: syscall.NlMsghdr{ + m.Consume(&NetlinkMessage{ + Header: NetlinkPacket{ Len: uint32(44), Type: uint16(1300), Flags: uint16(0), @@ -30,8 +29,8 @@ func TestAuditMarshaller_Consume(t *testing.T) { Data: []byte("audit(10000001:1): hi there"), }) - m.Consume(&syscall.NetlinkMessage{ - Header: syscall.NlMsghdr{ + m.Consume(&NetlinkMessage{ + Header: NetlinkPacket{ Len: uint32(44), Type: uint16(1301), Flags: uint16(0), @@ -52,8 +51,8 @@ func TestAuditMarshaller_Consume(t *testing.T) { // Ignore below 1100 w.Reset() - m.Consume(&syscall.NetlinkMessage{ - Header: syscall.NlMsghdr{ + m.Consume(&NetlinkMessage{ + Header: NetlinkPacket{ Len: uint32(44), Type: uint16(1099), Flags: uint16(0), @@ -67,8 +66,8 @@ func TestAuditMarshaller_Consume(t *testing.T) { // Ignore above 1399 w.Reset() - m.Consume(&syscall.NetlinkMessage{ - Header: syscall.NlMsghdr{ + m.Consume(&NetlinkMessage{ + Header: NetlinkPacket{ Len: uint32(44), Type: uint16(1400), Flags: uint16(0), @@ -82,8 +81,8 @@ func TestAuditMarshaller_Consume(t *testing.T) { // Ignore sequences of 0 w.Reset() - m.Consume(&syscall.NetlinkMessage{ - Header: syscall.NlMsghdr{ + m.Consume(&NetlinkMessage{ + Header: NetlinkPacket{ Len: uint32(44), Type: uint16(1400), Flags: uint16(0), @@ -97,8 +96,8 @@ func TestAuditMarshaller_Consume(t *testing.T) { // Should flush old msgs after 2 seconds w.Reset() - m.Consume(&syscall.NetlinkMessage{ - Header: syscall.NlMsghdr{ + m.Consume(&NetlinkMessage{ + Header: NetlinkPacket{ Len: uint32(44), Type: uint16(1300), Flags: uint16(0), @@ -126,8 +125,8 @@ func TestAuditMarshaller_completeMessage(t *testing.T) { // lb, elb := hookLogger() // m := NewAuditMarshaller(NewAuditWriter(&FailWriter{}, 1), uint16(1300), uint16(1399), false, false, 0, []AuditFilter{}) - // m.Consume(&syscall.NetlinkMessage{ - // Header: syscall.NlMsghdr{ + // m.Consume(&NetlinkMessage{ + // Header: NetlinkPacket{ // Len: uint32(44), // Type: uint16(1300), // Flags: uint16(0), @@ -142,9 +141,9 @@ func TestAuditMarshaller_completeMessage(t *testing.T) { // assert.Equal(t, "!", elb.String()) } -func new1320(seq string) *syscall.NetlinkMessage { - return &syscall.NetlinkMessage{ - Header: syscall.NlMsghdr{ +func new1320(seq string) *NetlinkMessage { + return &NetlinkMessage{ + Header: NetlinkPacket{ Len: uint32(44), Type: uint16(1320), Flags: uint16(0), diff --git a/parser.go b/parser.go index 6d51c43..26957ef 100644 --- a/parser.go +++ b/parser.go @@ -5,7 +5,6 @@ import ( "os/user" "strconv" "strings" - "syscall" "time" ) @@ -54,7 +53,7 @@ func NewAuditMessageGroup(am *AuditMessage) *AuditMessageGroup { } // Creates a new go-audit message from a netlink message -func NewAuditMessage(nlm *syscall.NetlinkMessage) *AuditMessage { +func NewAuditMessage(nlm *NetlinkMessage) *AuditMessage { aTime, seq := parseAuditHeader(nlm) return &AuditMessage{ Type: nlm.Header.Type, @@ -65,7 +64,7 @@ func NewAuditMessage(nlm *syscall.NetlinkMessage) *AuditMessage { } // Gets the timestamp and audit sequence id from a netlink message -func parseAuditHeader(msg *syscall.NetlinkMessage) (time string, seq int) { +func parseAuditHeader(msg *NetlinkMessage) (time string, seq int) { headerStop := bytes.Index(msg.Data, headerEndChar) // If the position the header appears to stop is less than the minimum length of a header, bail out if headerStop < HEADER_MIN_LENGTH { diff --git a/parser_test.go b/parser_test.go index 0584675..163d057 100644 --- a/parser_test.go +++ b/parser_test.go @@ -1,10 +1,10 @@ package main import ( - "github.com/stretchr/testify/assert" - "syscall" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestAuditConstants(t *testing.T) { @@ -15,8 +15,8 @@ func TestAuditConstants(t *testing.T) { } func TestNewAuditMessage(t *testing.T) { - msg := &syscall.NetlinkMessage{ - Header: syscall.NlMsghdr{ + msg := &NetlinkMessage{ + Header: NetlinkPacket{ Len: uint32(44), Type: uint16(1309), Flags: uint16(0),