Skip to content

Commit

Permalink
Migrate to using golang.org/x/sys/unix over syscall
Browse files Browse the repository at this point in the history
  • Loading branch information
bored-engineer committed Apr 18, 2020
1 parent 75218d0 commit 204af48
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 67 deletions.
4 changes: 2 additions & 2 deletions audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import (
"regexp"
"strconv"
"strings"
"syscall"

"github.com/spf13/viper"
"golang.org/x/sys/unix"
"gopkg.in/Graylog2/go-gelf.v2/gelf"
)

Expand Down Expand Up @@ -249,7 +249,7 @@ func handleLogRotation(config *viper.Viper, writer *AuditWriter) {
// Re-open our log file. This is triggered by a USR1 signal and is meant to be used upon log rotation

sigc := make(chan os.Signal, 1)
signal.Notify(sigc, syscall.SIGUSR1)
signal.Notify(sigc, unix.SIGUSR1)

for range sigc {
newWriter, err := createFileOutput(config)
Expand Down
10 changes: 5 additions & 5 deletions audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import (
"os/user"
"path"
"strconv"
"syscall"
"testing"
"time"

"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"golang.org/x/sys/unix"
"gopkg.in/Graylog2/go-gelf.v2/gelf"
)

Expand Down Expand Up @@ -413,7 +413,7 @@ func Test_createOutput(t *testing.T) {
os.Rename(path.Join(os.TempDir(), "go-audit.test.log"), path.Join(os.TempDir(), "go-audit.test.log.rotated"))
_, err = os.Stat(path.Join(os.TempDir(), "go-audit.test.log"))
assert.True(t, os.IsNotExist(err))
syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
unix.Kill(unix.Getpid(), unix.SIGUSR1)
time.Sleep(100 * time.Millisecond)
_, err = os.Stat(path.Join(os.TempDir(), "go-audit.test.log"))
assert.Nil(t, err)
Expand Down Expand Up @@ -565,15 +565,15 @@ func Benchmark_MultiPacketMessage(b *testing.B) {
for i := 0; i < b.N; i++ {
for n := 0; n < len(data); n++ {
nlen := len(data[n])
msg := &syscall.NetlinkMessage{
Header: syscall.NlMsghdr{
msg := &NetlinkMessage{
Header: NetlinkPacket{
Len: Endianness.Uint32(data[n][0:4]),
Type: Endianness.Uint16(data[n][4:6]),
Flags: Endianness.Uint16(data[n][6:8]),
Seq: Endianness.Uint32(data[n][8:12]),
Pid: Endianness.Uint32(data[n][12:16]),
},
Data: data[n][syscall.SizeofNlMsghdr:nlen],
Data: data[n][unix.SizeofNlMsghdr:nlen],
}
marshaller.Consume(msg)
}
Expand Down
45 changes: 26 additions & 19 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"sync/atomic"
"syscall"
"time"
"fmt"

"golang.org/x/sys/unix"
)

// Endianness is an alias for what we assume is the current machine endianness
Expand All @@ -33,42 +34,48 @@ type AuditStatusPayload struct {
}

// NetlinkPacket is an alias to give the header a similar name here
type NetlinkPacket syscall.NlMsghdr
type NetlinkPacket unix.NlMsghdr

// NetlinkMessage is copied from syscall.NetlinkMessage as x/sys/unix does not have it
type NetlinkMessage struct {
Header NetlinkPacket
Data []byte
}

type NetlinkClient struct {
fd int
address syscall.Sockaddr
address unix.Sockaddr
seq uint32
buf []byte
}

// NewNetlinkClient creates a new NetLinkClient and optionally tries to modify the netlink recv buffer
func NewNetlinkClient(recvSize int) (*NetlinkClient, error) {
fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, syscall.NETLINK_AUDIT)
fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_AUDIT)
if err != nil {
return nil, fmt.Errorf("Could not create a socket: %s", err)
}

n := &NetlinkClient{
fd: fd,
address: &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK, Groups: 0, Pid: 0},
address: &unix.SockaddrNetlink{Family: unix.AF_NETLINK, Groups: 0, Pid: 0},
buf: make([]byte, MAX_AUDIT_MESSAGE_LENGTH),
}

if err = syscall.Bind(fd, n.address); err != nil {
syscall.Close(fd)
if err = unix.Bind(fd, n.address); err != nil {
unix.Close(fd)
return nil, fmt.Errorf("Could not bind to netlink socket: %s", err)
}

// Set the buffer size if we were asked
if recvSize > 0 {
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, recvSize); err != nil {
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, recvSize); err != nil {
el.Println("Failed to set receive buffer size")
}
}

// Print the current receive buffer size
if v, err := syscall.GetsockoptInt(n.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF); err == nil {
if v, err := unix.GetsockoptInt(n.fd, unix.SOL_SOCKET, unix.SO_RCVBUF); err == nil {
l.Println("Socket receive buffer size:", v)
}

Expand Down Expand Up @@ -102,16 +109,16 @@ func (n *NetlinkClient) Send(np *NetlinkPacket, a *AuditStatusPayload) error {
}
}

if err := syscall.Sendto(n.fd, buf.Bytes(), 0, n.address); err != nil {
if err := unix.Sendto(n.fd, buf.Bytes(), 0, n.address); err != nil {
return err
}

return nil
}

// Receive will receive a packet from a netlink socket
func (n *NetlinkClient) Receive() (*syscall.NetlinkMessage, error) {
nlen, _, err := syscall.Recvfrom(n.fd, n.buf, 0)
func (n *NetlinkClient) Receive() (*NetlinkMessage, error) {
nlen, _, err := unix.Recvfrom(n.fd, n.buf, 0)
if err != nil {
return nil, err
}
Expand All @@ -120,15 +127,15 @@ func (n *NetlinkClient) Receive() (*syscall.NetlinkMessage, error) {
return nil, errors.New("Got a 0 length packet")
}

msg := &syscall.NetlinkMessage{
Header: syscall.NlMsghdr{
msg := &NetlinkMessage{
Header: NetlinkPacket{
Len: Endianness.Uint32(n.buf[0:4]),
Type: Endianness.Uint16(n.buf[4:6]),
Flags: Endianness.Uint16(n.buf[6:8]),
Seq: Endianness.Uint32(n.buf[8:12]),
Pid: Endianness.Uint32(n.buf[12:16]),
},
Data: n.buf[syscall.SizeofNlMsghdr:nlen],
Data: n.buf[unix.SizeofNlMsghdr:nlen],
}

return msg, nil
Expand All @@ -139,14 +146,14 @@ func (n *NetlinkClient) KeepConnection() {
payload := &AuditStatusPayload{
Mask: 4,
Enabled: 1,
Pid: uint32(syscall.Getpid()),
Pid: uint32(unix.Getpid()),
//TODO: Failure: http://lxr.free-electrons.com/source/include/uapi/linux/audit.h#L338
}

packet := &NetlinkPacket{
Type: uint16(1001),
Flags: syscall.NLM_F_REQUEST | syscall.NLM_F_ACK,
Pid: uint32(syscall.Getpid()),
Flags: unix.NLM_F_REQUEST | unix.NLM_F_ACK,
Pid: uint32(unix.Getpid()),
}

err := n.Send(packet, payload)
Expand Down
29 changes: 15 additions & 14 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ package main
import (
"bytes"
"encoding/binary"
"github.com/stretchr/testify/assert"
"os"
"syscall"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/sys/unix"
)

func TestNetlinkClient_KeepConnection(t *testing.T) {
n := makeNelinkClient(t)
defer syscall.Close(n.fd)
defer unix.Close(n.fd)

n.KeepConnection()
msg, err := n.Receive()
Expand All @@ -31,19 +32,19 @@ func TestNetlinkClient_KeepConnection(t *testing.T) {
// Make sure we get errors printed
lb, elb := hookLogger()
defer resetLogger()
syscall.Close(n.fd)
unix.Close(n.fd)
n.KeepConnection()
assert.Equal(t, "", lb.String(), "Got some log lines we did not expect")
assert.Equal(t, "Error occurred while trying to keep the connection: bad file descriptor\n", elb.String(), "Figured we would have an error")
}

func TestNetlinkClient_SendReceive(t *testing.T) {
var err error
var msg *syscall.NetlinkMessage
var msg *NetlinkMessage

// Build our client
n := makeNelinkClient(t)
defer syscall.Close(n.fd)
defer unix.Close(n.fd)

// Make sure we can encode/decode properly
payload := &AuditStatusPayload{
Expand All @@ -54,7 +55,7 @@ func TestNetlinkClient_SendReceive(t *testing.T) {

packet := &NetlinkPacket{
Type: uint16(1001),
Flags: syscall.NLM_F_REQUEST | syscall.NLM_F_ACK,
Flags: unix.NLM_F_REQUEST | unix.NLM_F_ACK,
Pid: uint32(1006),
}

Expand All @@ -72,12 +73,12 @@ func TestNetlinkClient_SendReceive(t *testing.T) {
assert.Equal(t, uint32(2), msg.Header.Seq, "Header.Seq did not increment")

// Make sure 0 length packets result in an error
syscall.Sendto(n.fd, []byte{}, 0, n.address)
unix.Sendto(n.fd, []byte{}, 0, n.address)
_, err = n.Receive()
assert.Equal(t, "Got a 0 length packet", err.Error(), "Error was incorrect")

// Make sure we get errors from sendto back
syscall.Close(n.fd)
unix.Close(n.fd)
err = n.Send(packet, payload)
assert.Equal(t, "bad file descriptor", err.Error(), "Error was incorrect")

Expand Down Expand Up @@ -110,27 +111,27 @@ func TestNewNetlinkClient(t *testing.T) {
// Helper to make a client listening on a unix socket
func makeNelinkClient(t *testing.T) *NetlinkClient {
os.Remove("go-audit.test.sock")
fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_RAW, 0)
fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_RAW, 0)
if err != nil {
t.Fatal("Could not create a socket:", err)
}

n := &NetlinkClient{
fd: fd,
address: &syscall.SockaddrUnix{Name: "go-audit.test.sock"},
address: &unix.SockaddrUnix{Name: "go-audit.test.sock"},
buf: make([]byte, MAX_AUDIT_MESSAGE_LENGTH),
}

if err = syscall.Bind(fd, n.address); err != nil {
syscall.Close(fd)
if err = unix.Bind(fd, n.address); err != nil {
unix.Close(fd)
t.Fatal("Could not bind to netlink socket:", err)
}

return n
}

// Helper to send and then receive a message with the netlink client
func sendReceive(t *testing.T, n *NetlinkClient, packet *NetlinkPacket, payload *AuditStatusPayload) *syscall.NetlinkMessage {
func sendReceive(t *testing.T, n *NetlinkClient, packet *NetlinkPacket, payload *AuditStatusPayload) *NetlinkMessage {
err := n.Send(packet, payload)
if err != nil {
t.Fatal("Failed to send:", err)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ require (
github.com/spf13/viper v0.0.0-20170217163817-7538d73b4eb9
github.com/stretchr/testify v1.2.2
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
google.golang.org/grpc v1.25.1 // indirect
gopkg.in/Graylog2/go-gelf.v2 v2.0.0-20180326133423-4dbb9d721348
Expand Down
3 changes: 1 addition & 2 deletions marshaller.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"os"
"regexp"
"syscall"
"time"
)

Expand Down Expand Up @@ -64,7 +63,7 @@ func NewAuditMarshaller(w *AuditWriter, eventMin uint16, eventMax uint16, trackM
}

// Ingests a netlink message and likely prepares it to be logged
func (a *AuditMarshaller) Consume(nlMsg *syscall.NetlinkMessage) {
func (a *AuditMarshaller) Consume(nlMsg *NetlinkMessage) {
aMsg := NewAuditMessage(nlMsg)

if aMsg.Seq == 0 {
Expand Down
Loading

0 comments on commit 204af48

Please sign in to comment.