Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
anjmao committed Dec 6, 2024
1 parent 94517a0 commit 00e6fc7
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 182 deletions.
7 changes: 2 additions & 5 deletions cmd/agent/daemon/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,8 @@ func buildEBPFPolicy(log *logging.Logger, cfg *Config, exporters *state.Exporter
}

dnsEventPolicy := &ebpftracer.EventPolicy{
ID: events.NetPacketDNSBase,
FilterGenerator: ebpftracer.FilterAnd(
ebpftracer.FilterEmptyDnsAnswers(log),
ebpftracer.DeduplicateDnsEvents(log, 100, 60*time.Second),
),
ID: events.NetPacketDNSBase,
PreFilterGenerator: ebpftracer.DnsEventsFilter(log, 100, 60*time.Second),
}

if cfg.ProcessTree.Enabled {
Expand Down
22 changes: 22 additions & 0 deletions pkg/ebpftracer/decoder/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,30 @@ func (decoder *Decoder) ReadAddrTuple() (types.AddrTuple, error) {

var errDNSMessageNotComplete = errors.New("received dns packet not complete")

// NOTE: This is not thread safe. Since currently only single go-routine reads the data this is fine.
var dnsPacketParser = &layers.DNS{}

func (decoder *Decoder) DecodeDnsLayer(details *packet.PacketDetails) (*layers.DNS, error) {
if details.Proto == packet.SubProtocolTCP {
if len(details.Payload) < 2 {
return nil, errDNSMessageNotComplete
}

// DNS over TCP prefixes the DNS message with a two octet length field. If the payload is not as big as this specified length,
// then we cannot parse the packet, as part of the DNS message will be send in a later one.
// For more information see https://datatracker.ietf.org/doc/html/rfc1035.html#section-4.2.2
length := int(binary.BigEndian.Uint16(details.Payload[:2]))
if len(details.Payload)+2 < length {
return nil, errDNSMessageNotComplete
}
details.Payload = details.Payload[2:]
}
if err := dnsPacketParser.DecodeFromBytes(details.Payload, gopacket.NilDecodeFeedback); err != nil {
return nil, err
}
return dnsPacketParser, nil
}

func (decoder *Decoder) ReadProtoDNS() (*types.ProtoDNS, error) {
data, err := decoder.ReadMaxByteSliceFromBuff(eventMaxByteSliceBufferSize(events.NetPacketDNSBase))
if err != nil {
Expand Down
26 changes: 6 additions & 20 deletions pkg/ebpftracer/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ebpftracer
import (
"time"

"github.com/castai/kvisor/pkg/ebpftracer/decoder"
"github.com/castai/kvisor/pkg/ebpftracer/events"
"github.com/castai/kvisor/pkg/ebpftracer/types"
)
Expand All @@ -14,8 +15,9 @@ type Policy struct {
Output PolicyOutputConfig
}

// PreEventFilter allows for filtering of events coming from the kernel before they are decoded
type PreEventFilter func(ctx *types.EventContext) error
// PreEventFilter allows for filtering of events coming from the kernel before they are decoded.
// Parsed args should be returned if filter passes.
type PreEventFilter func(ctx *types.EventContext, decoder *decoder.Decoder) (types.Args, error)

// EventFilterGenerator Produces an pre event filter for each call
type PreEventFilterGenerator func() PreEventFilter
Expand Down Expand Up @@ -47,8 +49,8 @@ type LRUPolicy struct {
}

type PolicyOutputConfig struct {
RelativeTime bool
ExecHash bool
RelativeTime bool
ExecHash bool

ParseArguments bool
ParseArgumentsFDs bool
Expand All @@ -74,19 +76,3 @@ type cgroupEventPolicy struct {
preFilter PreEventFilter
filter EventFilter
}

func (c *cgroupEventPolicy) allowPre(ctx *types.EventContext) error {
if c.preFilter != nil {
return c.preFilter(ctx)
}

return nil
}

func (c *cgroupEventPolicy) allow(event *types.Event) error {
if c.filter != nil {
return c.filter(event)
}

return nil
}
104 changes: 88 additions & 16 deletions pkg/ebpftracer/policy_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ import (
"net/netip"
"time"

castpb "github.com/castai/kvisor/api/v1/runtime"
"github.com/castai/kvisor/pkg/ebpftracer/decoder"
"github.com/castai/kvisor/pkg/ebpftracer/events"
"github.com/castai/kvisor/pkg/ebpftracer/types"
"github.com/castai/kvisor/pkg/logging"
"github.com/castai/kvisor/pkg/net/packet"
"github.com/cespare/xxhash/v2"
"github.com/elastic/go-freelru"
"github.com/google/gopacket/layers"
"github.com/samber/lo"
"golang.org/x/time/rate"
)
Expand Down Expand Up @@ -56,22 +60,6 @@ func FilterAnd(filtersGenerators ...EventFilterGenerator) EventFilterGenerator {
}
}

// PreRateLimit creates an pre event filter that limits the amount of events that will be
// processed accoring to the specified limits
func PreRateLimit(spec RateLimitPolicy) PreEventFilterGenerator {
return func() PreEventFilter {
rateLimiter := newRateLimiter(spec)

return func(ctx *types.EventContext) error {
if rateLimiter.Allow() {
return FilterPass
}

return FilterErrRateLimit
}
}
}

func RateLimit(spec RateLimitPolicy) EventFilterGenerator {
return func() EventFilter {
rateLimiter := newRateLimiter(spec)
Expand Down Expand Up @@ -216,6 +204,90 @@ func DeduplicateDnsEvents(l *logging.Logger, size uint32, ttl time.Duration) Eve
}
}

func DnsEventsFilter(log *logging.Logger, size uint32, ttl time.Duration) PreEventFilterGenerator {
type cacheValue struct{}

return func() PreEventFilter {
cache, err := freelru.New[uint32, cacheValue](size, func(key uint32) uint32 {
return key
})
// err is only ever returned on configuration issues. There is nothing we can really do here, besides
// panicing and surfacing the error to the user.
if err != nil {
panic(err)
}

cache.SetLifetime(ttl)

return func(ctx *types.EventContext, decoder *decoder.Decoder) (types.Args, error) {
if ctx.EventID != events.NetPacketDNSBase {
return nil, FilterPass
}

packetData, err := decoder.ReadMaxByteSliceFromBuff(-1)
if err != nil {
return nil, err
}

details, err := packet.ExtractPacketDetails(packetData)
if err != nil {
return nil, err
}

dns, err := decoder.DecodeDnsLayer(&details)
if err != nil {
return nil, err
}
if len(dns.Questions) == 0 {
return nil, FilterErrEmptyDNSResponse
}

cacheKey := uint32(xxhash.Sum64(dns.Questions[0].Name))

Check failure on line 245 in pkg/ebpftracer/policy_filters.go

View workflow job for this annotation

GitHub Actions / Build

G115: integer overflow conversion uint64 -> uint32 (gosec)
if cache.Contains(cacheKey) {
if log.IsEnabled(slog.LevelDebug) {
log.WithField("cachekey", string(dns.Questions[0].Name)).Debug("dropping DNS event")
}
return nil, FilterErrDNSDuplicateDetected
}
cache.Add(cacheKey, cacheValue{})

result := types.NetPacketDNSBaseArgs{
Payload: toProtoDNS(&details, dns),
}
return result, FilterPass
}
}
}

func toProtoDNS(details *packet.PacketDetails, dnsPacketParser *layers.DNS) *castpb.DNS {
pbDNS := &castpb.DNS{
Answers: make([]*castpb.DNSAnswers, len(dnsPacketParser.Answers)),
Tuple: &castpb.Tuple{
SrcIp: details.Src.Addr().AsSlice(),
DstIp: details.Dst.Addr().AsSlice(),
SrcPort: uint32(details.Src.Port()),
DstPort: uint32(details.Dst.Port()),
},
}

for _, v := range dnsPacketParser.Questions {
pbDNS.DNSQuestionDomain = string(v.Name)
break
}

for i, v := range dnsPacketParser.Answers {
pbDNS.Answers[i] = &castpb.DNSAnswers{
Name: string(v.Name),
Type: uint32(v.Type),
Class: uint32(v.Class),
Ttl: v.TTL,
Ip: v.IP,
Cname: string(v.CNAME),
}
}
return pbDNS
}

func isPrivateNetwork(ip netip.Addr) bool {
return ip.IsPrivate() ||
ip.IsLoopback() ||
Expand Down
24 changes: 1 addition & 23 deletions pkg/ebpftracer/tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,29 +475,7 @@ func (t *Tracer) initTailCall(tailCall TailCall) error {
return nil
}

func (t *Tracer) allowedByPolicyPre(ctx *types.EventContext) error {
policy := t.getPolicy(ctx.EventID, ctx.CgroupID)

if policy != nil {
return policy.allowPre(ctx)
}

// No policy.
return nil
}

func (t *Tracer) allowedByPolicy(eventID events.ID, cgroupID uint64, event *types.Event) error {
policy := t.getPolicy(eventID, cgroupID)

if policy != nil {
return policy.allow(event)
}

// No policy.
return nil
}

func (t *Tracer) getPolicy(eventID events.ID, cgroupID uint64) *cgroupEventPolicy {
func (t *Tracer) getFilterPolicy(eventID events.ID, cgroupID uint64) *cgroupEventPolicy {
t.policyMu.Lock()
defer t.policyMu.Unlock()

Expand Down
Loading

0 comments on commit 00e6fc7

Please sign in to comment.