From 00e6fc7f5e99de47c2fced24f8d048f8df236705 Mon Sep 17 00:00:00 2001 From: anjmao Date: Fri, 6 Dec 2024 15:38:00 +0200 Subject: [PATCH] WIP --- cmd/agent/daemon/app/app.go | 7 +- pkg/ebpftracer/decoder/decoder.go | 22 +++ pkg/ebpftracer/policy.go | 26 +-- pkg/ebpftracer/policy_filters.go | 104 ++++++++++-- pkg/ebpftracer/tracer.go | 24 +-- pkg/ebpftracer/tracer_decode.go | 259 ++++++++++++++++-------------- 6 files changed, 260 insertions(+), 182 deletions(-) diff --git a/cmd/agent/daemon/app/app.go b/cmd/agent/daemon/app/app.go index bbdecd55..ee01f472 100644 --- a/cmd/agent/daemon/app/app.go +++ b/cmd/agent/daemon/app/app.go @@ -398,11 +398,8 @@ func buildEBPFPolicy(log *logging.Logger, cfg *Config, exporters *state.Exporter } dnsEventPolicy := &ebpftracer.EventPolicy{ - ID: events.NetPacketDNSBase, - FilterGenerator: ebpftracer.FilterAnd( - ebpftracer.FilterEmptyDnsAnswers(log), - ebpftracer.DeduplicateDnsEvents(log, 100, 60*time.Second), - ), + ID: events.NetPacketDNSBase, + PreFilterGenerator: ebpftracer.DnsEventsFilter(log, 100, 60*time.Second), } if cfg.ProcessTree.Enabled { diff --git a/pkg/ebpftracer/decoder/decoder.go b/pkg/ebpftracer/decoder/decoder.go index f92ffe0d..3fdd3504 100644 --- a/pkg/ebpftracer/decoder/decoder.go +++ b/pkg/ebpftracer/decoder/decoder.go @@ -519,8 +519,30 @@ func (decoder *Decoder) ReadAddrTuple() (types.AddrTuple, error) { var errDNSMessageNotComplete = errors.New("received dns packet not complete") +// NOTE: This is not thread safe. Since currently only single go-routine reads the data this is fine. var dnsPacketParser = &layers.DNS{} +func (decoder *Decoder) DecodeDnsLayer(details *packet.PacketDetails) (*layers.DNS, error) { + if details.Proto == packet.SubProtocolTCP { + if len(details.Payload) < 2 { + return nil, errDNSMessageNotComplete + } + + // DNS over TCP prefixes the DNS message with a two octet length field. If the payload is not as big as this specified length, + // then we cannot parse the packet, as part of the DNS message will be send in a later one. + // For more information see https://datatracker.ietf.org/doc/html/rfc1035.html#section-4.2.2 + length := int(binary.BigEndian.Uint16(details.Payload[:2])) + if len(details.Payload)+2 < length { + return nil, errDNSMessageNotComplete + } + details.Payload = details.Payload[2:] + } + if err := dnsPacketParser.DecodeFromBytes(details.Payload, gopacket.NilDecodeFeedback); err != nil { + return nil, err + } + return dnsPacketParser, nil +} + func (decoder *Decoder) ReadProtoDNS() (*types.ProtoDNS, error) { data, err := decoder.ReadMaxByteSliceFromBuff(eventMaxByteSliceBufferSize(events.NetPacketDNSBase)) if err != nil { diff --git a/pkg/ebpftracer/policy.go b/pkg/ebpftracer/policy.go index 4ce9bddc..e1aea3db 100644 --- a/pkg/ebpftracer/policy.go +++ b/pkg/ebpftracer/policy.go @@ -3,6 +3,7 @@ package ebpftracer import ( "time" + "github.com/castai/kvisor/pkg/ebpftracer/decoder" "github.com/castai/kvisor/pkg/ebpftracer/events" "github.com/castai/kvisor/pkg/ebpftracer/types" ) @@ -14,8 +15,9 @@ type Policy struct { Output PolicyOutputConfig } -// PreEventFilter allows for filtering of events coming from the kernel before they are decoded -type PreEventFilter func(ctx *types.EventContext) error +// PreEventFilter allows for filtering of events coming from the kernel before they are decoded. +// Parsed args should be returned if filter passes. +type PreEventFilter func(ctx *types.EventContext, decoder *decoder.Decoder) (types.Args, error) // EventFilterGenerator Produces an pre event filter for each call type PreEventFilterGenerator func() PreEventFilter @@ -47,8 +49,8 @@ type LRUPolicy struct { } type PolicyOutputConfig struct { - RelativeTime bool - ExecHash bool + RelativeTime bool + ExecHash bool ParseArguments bool ParseArgumentsFDs bool @@ -74,19 +76,3 @@ type cgroupEventPolicy struct { preFilter PreEventFilter filter EventFilter } - -func (c *cgroupEventPolicy) allowPre(ctx *types.EventContext) error { - if c.preFilter != nil { - return c.preFilter(ctx) - } - - return nil -} - -func (c *cgroupEventPolicy) allow(event *types.Event) error { - if c.filter != nil { - return c.filter(event) - } - - return nil -} diff --git a/pkg/ebpftracer/policy_filters.go b/pkg/ebpftracer/policy_filters.go index c8651525..5ffb9dff 100644 --- a/pkg/ebpftracer/policy_filters.go +++ b/pkg/ebpftracer/policy_filters.go @@ -6,11 +6,15 @@ import ( "net/netip" "time" + castpb "github.com/castai/kvisor/api/v1/runtime" + "github.com/castai/kvisor/pkg/ebpftracer/decoder" "github.com/castai/kvisor/pkg/ebpftracer/events" "github.com/castai/kvisor/pkg/ebpftracer/types" "github.com/castai/kvisor/pkg/logging" + "github.com/castai/kvisor/pkg/net/packet" "github.com/cespare/xxhash/v2" "github.com/elastic/go-freelru" + "github.com/google/gopacket/layers" "github.com/samber/lo" "golang.org/x/time/rate" ) @@ -56,22 +60,6 @@ func FilterAnd(filtersGenerators ...EventFilterGenerator) EventFilterGenerator { } } -// PreRateLimit creates an pre event filter that limits the amount of events that will be -// processed accoring to the specified limits -func PreRateLimit(spec RateLimitPolicy) PreEventFilterGenerator { - return func() PreEventFilter { - rateLimiter := newRateLimiter(spec) - - return func(ctx *types.EventContext) error { - if rateLimiter.Allow() { - return FilterPass - } - - return FilterErrRateLimit - } - } -} - func RateLimit(spec RateLimitPolicy) EventFilterGenerator { return func() EventFilter { rateLimiter := newRateLimiter(spec) @@ -216,6 +204,90 @@ func DeduplicateDnsEvents(l *logging.Logger, size uint32, ttl time.Duration) Eve } } +func DnsEventsFilter(log *logging.Logger, size uint32, ttl time.Duration) PreEventFilterGenerator { + type cacheValue struct{} + + return func() PreEventFilter { + cache, err := freelru.New[uint32, cacheValue](size, func(key uint32) uint32 { + return key + }) + // err is only ever returned on configuration issues. There is nothing we can really do here, besides + // panicing and surfacing the error to the user. + if err != nil { + panic(err) + } + + cache.SetLifetime(ttl) + + return func(ctx *types.EventContext, decoder *decoder.Decoder) (types.Args, error) { + if ctx.EventID != events.NetPacketDNSBase { + return nil, FilterPass + } + + packetData, err := decoder.ReadMaxByteSliceFromBuff(-1) + if err != nil { + return nil, err + } + + details, err := packet.ExtractPacketDetails(packetData) + if err != nil { + return nil, err + } + + dns, err := decoder.DecodeDnsLayer(&details) + if err != nil { + return nil, err + } + if len(dns.Questions) == 0 { + return nil, FilterErrEmptyDNSResponse + } + + cacheKey := uint32(xxhash.Sum64(dns.Questions[0].Name)) + if cache.Contains(cacheKey) { + if log.IsEnabled(slog.LevelDebug) { + log.WithField("cachekey", string(dns.Questions[0].Name)).Debug("dropping DNS event") + } + return nil, FilterErrDNSDuplicateDetected + } + cache.Add(cacheKey, cacheValue{}) + + result := types.NetPacketDNSBaseArgs{ + Payload: toProtoDNS(&details, dns), + } + return result, FilterPass + } + } +} + +func toProtoDNS(details *packet.PacketDetails, dnsPacketParser *layers.DNS) *castpb.DNS { + pbDNS := &castpb.DNS{ + Answers: make([]*castpb.DNSAnswers, len(dnsPacketParser.Answers)), + Tuple: &castpb.Tuple{ + SrcIp: details.Src.Addr().AsSlice(), + DstIp: details.Dst.Addr().AsSlice(), + SrcPort: uint32(details.Src.Port()), + DstPort: uint32(details.Dst.Port()), + }, + } + + for _, v := range dnsPacketParser.Questions { + pbDNS.DNSQuestionDomain = string(v.Name) + break + } + + for i, v := range dnsPacketParser.Answers { + pbDNS.Answers[i] = &castpb.DNSAnswers{ + Name: string(v.Name), + Type: uint32(v.Type), + Class: uint32(v.Class), + Ttl: v.TTL, + Ip: v.IP, + Cname: string(v.CNAME), + } + } + return pbDNS +} + func isPrivateNetwork(ip netip.Addr) bool { return ip.IsPrivate() || ip.IsLoopback() || diff --git a/pkg/ebpftracer/tracer.go b/pkg/ebpftracer/tracer.go index 10855c31..dd57af86 100644 --- a/pkg/ebpftracer/tracer.go +++ b/pkg/ebpftracer/tracer.go @@ -475,29 +475,7 @@ func (t *Tracer) initTailCall(tailCall TailCall) error { return nil } -func (t *Tracer) allowedByPolicyPre(ctx *types.EventContext) error { - policy := t.getPolicy(ctx.EventID, ctx.CgroupID) - - if policy != nil { - return policy.allowPre(ctx) - } - - // No policy. - return nil -} - -func (t *Tracer) allowedByPolicy(eventID events.ID, cgroupID uint64, event *types.Event) error { - policy := t.getPolicy(eventID, cgroupID) - - if policy != nil { - return policy.allow(event) - } - - // No policy. - return nil -} - -func (t *Tracer) getPolicy(eventID events.ID, cgroupID uint64) *cgroupEventPolicy { +func (t *Tracer) getFilterPolicy(eventID events.ID, cgroupID uint64) *cgroupEventPolicy { t.policyMu.Lock() defer t.policyMu.Unlock() diff --git a/pkg/ebpftracer/tracer_decode.go b/pkg/ebpftracer/tracer_decode.go index 6738c3a2..0a727cdb 100644 --- a/pkg/ebpftracer/tracer_decode.go +++ b/pkg/ebpftracer/tracer_decode.go @@ -19,23 +19,9 @@ import ( "golang.org/x/net/context" ) -// Error indicating that the resulting error was caught from a panic +// ErrPanic indicating that the resulting error was caught from a panic var ErrPanic = errors.New("encountered panic") -func decodeContextAndArgs(ebpfMsgDecoder *decoder.Decoder) (types.EventContext, types.Args, error) { - var eventCtx types.EventContext - if err := ebpfMsgDecoder.DecodeContext(&eventCtx); err != nil { - return types.EventContext{}, nil, fmt.Errorf("decoding context: %w", err) - } - eventId := eventCtx.EventID - parsedArgs, err := decoder.ParseArgs(ebpfMsgDecoder, eventId) - if err != nil { - return types.EventContext{}, nil, fmt.Errorf("parsing event %d args: %w", eventId, err) - } - - return eventCtx, parsedArgs, nil -} - func (t *Tracer) decodeAndExportEvent(ctx context.Context, ebpfMsgDecoder *decoder.Decoder) (rerr error) { defer func() { if perr := recover(); perr != nil { @@ -44,45 +30,42 @@ func (t *Tracer) decodeAndExportEvent(ctx context.Context, ebpfMsgDecoder *decod } }() - eventCtx, parsedArgs, err := decodeContextAndArgs(ebpfMsgDecoder) - if err != nil { - return err + var eventCtx types.EventContext + if err := ebpfMsgDecoder.DecodeContext(&eventCtx); err != nil { + return fmt.Errorf("decoding context: %w", err) } + eventId := eventCtx.EventID def := t.eventsSet[eventCtx.EventID] metrics.AgentPulledEventsBytesTotal.WithLabelValues(def.name).Add(float64(ebpfMsgDecoder.BuffLen())) metrics.AgentPulledEventsTotal.WithLabelValues(def.name).Inc() - // Process special events for cgroup creation and removal. - // These are system events which are not send down via events pipeline. - switch eventId { - case events.CgroupMkdir: - args := parsedArgs.(types.CgroupMkdirArgs) - // We we only care about events from the default cgroup, as cgroup v1 does not have unified cgroups. - if !t.cfg.CgroupClient.IsDefaultHierarchy(args.HierarchyId) { + filterPolicy := t.getFilterPolicy(eventCtx.EventID, eventCtx.CgroupID) + + var err error + var parsedArgs types.Args + if filterPolicy != nil && filterPolicy.preFilter != nil { + parsedArgs, err = filterPolicy.preFilter(&eventCtx, ebpfMsgDecoder) + if err != nil { + metrics.AgentSkippedEventsTotal.WithLabelValues(def.name).Inc() return nil } - t.cfg.CgroupClient.LoadCgroup(args.CgroupId, args.CgroupPath) - if _, err := t.cfg.ContainerClient.AddContainerByCgroupID(context.Background(), args.CgroupId); err != nil { - if errors.Is(err, containers.ErrContainerNotFound) { - err := t.MuteEventsFromCgroup(eventCtx.CgroupID) - if err != nil { - return fmt.Errorf("cannot mute events for cgroup %d: %w", eventCtx.CgroupID, err) - } - return nil - } - t.log.Errorf("cannot add container to cgroup %d: %b", args.CgroupId, err) - } - return nil - case events.CgroupRmdir: - args := parsedArgs.(types.CgroupRmdirArgs) + } - t.queueCgroupForRemoval(args.CgroupId) - err := t.UnmuteEventsFromCgroup(args.CgroupId) + if parsedArgs == nil { + parsedArgs, err = decoder.ParseArgs(ebpfMsgDecoder, eventId) if err != nil { - return fmt.Errorf("cannot remove cgroup %d from mute map: %w", args.CgroupId, err) + return fmt.Errorf("parsing event %d args: %w", eventId, err) } - return nil + } + + // Process special events for cgroup creation and removal. + // These are system events which are not send down via events pipeline. + switch eventId { + case events.CgroupMkdir: + return t.handleCgroupMkdirEvent(&eventCtx, parsedArgs) + case events.CgroupRmdir: + return t.handleCgroupRmdirEvent(parsedArgs) default: } @@ -109,75 +92,13 @@ func (t *Tracer) decodeAndExportEvent(ctx context.Context, ebpfMsgDecoder *decod switch eventId { case events.SchedProcessExec: - if eventCtx.Pid == 1 { - t.cfg.MountNamespacePIDStore.ForceAddToBucket(proc.NamespaceID(eventCtx.MntID), eventCtx.NodeHostPid) - } else { - t.cfg.MountNamespacePIDStore.AddToBucket(proc.NamespaceID(eventCtx.MntID), eventCtx.NodeHostPid) - } - - parentStartTime := time.Duration(0) - if eventCtx.Ppid != 0 { - // We only set the parent start time, if we know the parent PID comes from the same NS. - parentStartTime = time.Duration(eventCtx.ParentStartTime) * time.Nanosecond // nolint:gosec - } - execArgs, ok := parsedArgs.(types.SchedProcessExecArgs) - if !ok { - t.log.Errorf("expected types.SchedProcessExecArgs, but got: %t", parsedArgs) - return nil + if err := t.handleSchedProcessExecEvent(&eventCtx, parsedArgs, container, rawEventTime); err != nil { + return err } - processStartTime := time.Duration(eventCtx.StartTime) * time.Nanosecond // nolint:gosec - - t.cfg.ProcessTreeCollector.ProcessStarted( - system.GetBootTime().Add(time.Duration(rawEventTime)), // nolint:gosec - container.ID, - processtree.Process{ - PID: proc.PID(eventCtx.Pid), - StartTime: processStartTime.Truncate(time.Second), - PPID: proc.PID(eventCtx.Ppid), - ParentStartTime: parentStartTime.Truncate(time.Second), - Args: execArgs.Argv, - FilePath: execArgs.Filepath, - }, - ) - case events.SchedProcessExit, events.ProcessOomKilled: - // We only care about process exits and not threads. - if eventCtx.HostPid == eventCtx.HostTid { - parentStartTime := time.Duration(0) - if eventCtx.Ppid != 0 { - // We only set the parent start time, if we know the parent PID comes from the same NS. - parentStartTime = time.Duration(eventCtx.ParentStartTime) * time.Nanosecond // nolint:gosec - } - - t.cfg.ProcessTreeCollector.ProcessExited( - system.GetBootTime().Add(time.Duration(rawEventTime)), // nolint:gosec - container.ID, - processtree.ToProcessKeyNs( - proc.PID(eventCtx.Pid), - eventCtx.StartTime), - processtree.ToProcessKey(proc.PID(eventCtx.Ppid), parentStartTime), - eventCtx.Ts, - ) - } - + t.handleSchedProcessExitEvent(&eventCtx, container, rawEventTime) case events.SchedProcessFork: - forkArgs := parsedArgs.(types.SchedProcessForkArgs) - - // ChildPID equals ParentPID indicates that the child is probably a thread. We do not care about threads. - if forkArgs.ChildNsPid != forkArgs.ParentNsPid { - parentStartTime := uint64(0) - if forkArgs.UpParentPid != 0 { - parentStartTime = forkArgs.UpParentStartTime - } - - t.cfg.ProcessTreeCollector.ProcessForked( - // We always assume the child start time as the event timestamp for forks. - system.GetBootTime().Add(time.Duration(forkArgs.ChildStartTime)), // nolint:gosec - container.ID, - processtree.ToProcessKeyNs(proc.PID(forkArgs.ParentNsPid), parentStartTime), // nolint:gosec - processtree.ToProcessKeyNs(proc.PID(forkArgs.ChildNsPid), forkArgs.ChildStartTime), //nolint:gosec - ) - } + t.handleSchedProcessForkEvent(&eventCtx, parsedArgs, container) default: } @@ -191,15 +112,11 @@ func (t *Tracer) decodeAndExportEvent(ctx context.Context, ebpfMsgDecoder *decod return nil } - // TODO: Move rate limit based policy to kernel side. - if err := t.allowedByPolicyPre(&eventCtx); err != nil { - metrics.AgentSkippedEventsTotal.WithLabelValues(def.name).Inc() - return nil - } - - if err := t.allowedByPolicy(eventId, eventCtx.CgroupID, event); err != nil { - metrics.AgentSkippedEventsTotal.WithLabelValues(def.name).Inc() - return nil + if filterPolicy != nil && filterPolicy.filter != nil { + if err := filterPolicy.filter(event); err != nil { + metrics.AgentSkippedEventsTotal.WithLabelValues(def.name).Inc() + return nil + } } select { @@ -211,6 +128,112 @@ func (t *Tracer) decodeAndExportEvent(ctx context.Context, ebpfMsgDecoder *decod return nil } +func (t *Tracer) handleCgroupMkdirEvent(eventCtx *types.EventContext, parsedArgs types.Args) error { + args := parsedArgs.(types.CgroupMkdirArgs) + // We we only care about events from the default cgroup, as cgroup v1 does not have unified cgroups. + if !t.cfg.CgroupClient.IsDefaultHierarchy(args.HierarchyId) { + return nil + } + t.cfg.CgroupClient.LoadCgroup(args.CgroupId, args.CgroupPath) + if _, err := t.cfg.ContainerClient.AddContainerByCgroupID(context.Background(), args.CgroupId); err != nil { + if errors.Is(err, containers.ErrContainerNotFound) { + err := t.MuteEventsFromCgroup(eventCtx.CgroupID) + if err != nil { + return fmt.Errorf("cannot mute events for cgroup %d: %w", eventCtx.CgroupID, err) + } + return nil + } + t.log.Errorf("cannot add container to cgroup %d: %b", args.CgroupId, err) + } + return nil +} + +func (t *Tracer) handleCgroupRmdirEvent(parsedArgs types.Args) error { + args := parsedArgs.(types.CgroupRmdirArgs) + + t.queueCgroupForRemoval(args.CgroupId) + err := t.UnmuteEventsFromCgroup(args.CgroupId) + if err != nil { + return fmt.Errorf("cannot remove cgroup %d from mute map: %w", args.CgroupId, err) + } + return nil +} + +func (t *Tracer) handleSchedProcessExecEvent(eventCtx *types.EventContext, parsedArgs types.Args, container *containers.Container, rawEventTime uint64) error { + if eventCtx.Pid == 1 { + t.cfg.MountNamespacePIDStore.ForceAddToBucket(proc.NamespaceID(eventCtx.MntID), eventCtx.NodeHostPid) + } else { + t.cfg.MountNamespacePIDStore.AddToBucket(proc.NamespaceID(eventCtx.MntID), eventCtx.NodeHostPid) + } + + parentStartTime := time.Duration(0) + if eventCtx.Ppid != 0 { + // We only set the parent start time, if we know the parent PID comes from the same NS. + parentStartTime = time.Duration(eventCtx.ParentStartTime) * time.Nanosecond // nolint:gosec + } + execArgs, ok := parsedArgs.(types.SchedProcessExecArgs) + if !ok { + return fmt.Errorf("expected types.SchedProcessExecArgs, but got: %t", parsedArgs) + } + processStartTime := time.Duration(eventCtx.StartTime) * time.Nanosecond // nolint:gosec + + t.cfg.ProcessTreeCollector.ProcessStarted( + system.GetBootTime().Add(time.Duration(rawEventTime)), // nolint:gosec + container.ID, + processtree.Process{ + PID: proc.PID(eventCtx.Pid), + StartTime: processStartTime.Truncate(time.Second), + PPID: proc.PID(eventCtx.Ppid), + ParentStartTime: parentStartTime.Truncate(time.Second), + Args: execArgs.Argv, + FilePath: execArgs.Filepath, + }, + ) + return nil +} + +func (t *Tracer) handleSchedProcessExitEvent(eventCtx *types.EventContext, container *containers.Container, rawEventTime uint64) { + // We only care about process exits and not threads. + if eventCtx.HostPid != eventCtx.HostTid { + return + } + parentStartTime := time.Duration(0) + if eventCtx.Ppid != 0 { + // We only set the parent start time, if we know the parent PID comes from the same NS. + parentStartTime = time.Duration(eventCtx.ParentStartTime) * time.Nanosecond // nolint:gosec + } + + t.cfg.ProcessTreeCollector.ProcessExited( + system.GetBootTime().Add(time.Duration(rawEventTime)), // nolint:gosec + container.ID, + processtree.ToProcessKeyNs( + proc.PID(eventCtx.Pid), + eventCtx.StartTime), + processtree.ToProcessKey(proc.PID(eventCtx.Ppid), parentStartTime), + eventCtx.Ts, + ) +} + +func (t *Tracer) handleSchedProcessForkEvent(eventCtx *types.EventContext, parsedArgs types.Args, container *containers.Container) { + forkArgs := parsedArgs.(types.SchedProcessForkArgs) + + // ChildPID equals ParentPID indicates that the child is probably a thread. We do not care about threads. + if forkArgs.ChildNsPid != forkArgs.ParentNsPid { + parentStartTime := uint64(0) + if forkArgs.UpParentPid != 0 { + parentStartTime = forkArgs.UpParentStartTime + } + + t.cfg.ProcessTreeCollector.ProcessForked( + // We always assume the child start time as the event timestamp for forks. + system.GetBootTime().Add(time.Duration(forkArgs.ChildStartTime)), // nolint:gosec + container.ID, + processtree.ToProcessKeyNs(proc.PID(forkArgs.ParentNsPid), parentStartTime), // nolint:gosec + processtree.ToProcessKeyNs(proc.PID(forkArgs.ChildNsPid), forkArgs.ChildStartTime), //nolint:gosec + ) + } +} + func (t *Tracer) MuteEventsFromCgroup(cgroup uint64) error { t.log.Infof("muting cgroup %d", cgroup) return t.module.objects.IgnoredCgroupsMap.Put(cgroup, cgroup)