From f93ebd27472302ee292cdbc61c777505675be3bd Mon Sep 17 00:00:00 2001
From: Leon Hwang <hffilwlqm@gmail.com>
Date: Sun, 10 Nov 2024 14:58:26 +0800
Subject: [PATCH] bpf: Remove HAS_KPROBE_MULTI macro

Simplify bpf code.

And use Go to determine loading kprobe or kprobe.multi bpf prog.

Signed-off-by: Leon Hwang <hffilwlqm@gmail.com>
---
 .gitignore              |  1 -
 bpf/kprobe_pwru.c       | 34 +++++++++++++++-------------------
 build.go                |  1 -
 internal/pwru/kprobe.go |  5 +++--
 main.go                 | 18 +++++++++++-------
 5 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/.gitignore b/.gitignore
index 834a28eb..b84e6dd5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,5 @@ pwru
 release
 tags
 kprobepwru_*
-kprobemultipwru_*
 featurespwru_*
 !pwru/
diff --git a/bpf/kprobe_pwru.c b/bpf/kprobe_pwru.c
index 71d3382f..8e66152d 100644
--- a/bpf/kprobe_pwru.c
+++ b/bpf/kprobe_pwru.c
@@ -492,7 +492,7 @@ handle_everything(struct sk_buff *skb, void *ctx, struct event_t *event, u64 *_s
 }
 
 static __always_inline int
-kprobe_skb(struct sk_buff *skb, struct pt_regs *ctx, bool has_get_func_ip, u64 *_stackid) {
+kprobe_skb(struct sk_buff *skb, struct pt_regs *ctx, const bool has_get_func_ip, u64 *_stackid) {
 	struct event_t event = {};
 
 	if (!handle_everything(skb, ctx, &event, _stackid, true))
@@ -510,19 +510,17 @@ kprobe_skb(struct sk_buff *skb, struct pt_regs *ctx, bool has_get_func_ip, u64 *
 	return BPF_OK;
 }
 
-#ifdef HAS_KPROBE_MULTI
-#define PWRU_KPROBE_TYPE "kprobe.multi"
-#define PWRU_HAS_GET_FUNC_IP true
-#else
-#define PWRU_KPROBE_TYPE "kprobe"
-#define PWRU_HAS_GET_FUNC_IP false
-#endif /* HAS_KPROBE_MULTI */
-
-#define PWRU_ADD_KPROBE(X)                                                     \
-  SEC(PWRU_KPROBE_TYPE "/skb-" #X)                                             \
-  int kprobe_skb_##X(struct pt_regs *ctx) {                                    \
-    struct sk_buff *skb = (struct sk_buff *) PT_REGS_PARM##X(ctx);             \
-    return kprobe_skb(skb, ctx, PWRU_HAS_GET_FUNC_IP, NULL);                         \
+#define PWRU_ADD_KPROBE(X)                                                  \
+  SEC("kprobe/skb-" #X)                                                     \
+  int kprobe_skb_##X(struct pt_regs *ctx) {                                 \
+    struct sk_buff *skb = (struct sk_buff *) PT_REGS_PARM##X(ctx);          \
+    return kprobe_skb(skb, ctx, false, NULL);                               \
+  }                                                                         \
+                                                                            \
+  SEC("kprobe.multi/skb-" #X)                                               \
+  int kprobe_multi_skb_##X(struct pt_regs *ctx) {                           \
+    struct sk_buff *skb = (struct sk_buff *) PT_REGS_PARM##X(ctx);          \
+    return kprobe_skb(skb, ctx, true, NULL);                                \
   }
 
 PWRU_ADD_KPROBE(1)
@@ -531,21 +529,19 @@ PWRU_ADD_KPROBE(3)
 PWRU_ADD_KPROBE(4)
 PWRU_ADD_KPROBE(5)
 
+#undef PWRU_ADD_KPROBE
+
 SEC("kprobe/skb_by_stackid")
 int kprobe_skb_by_stackid(struct pt_regs *ctx) {
 	u64 stackid = get_stackid(ctx);
 
 	struct sk_buff **skb = bpf_map_lookup_elem(&stackid_skb, &stackid);
 	if (skb && *skb)
-		return kprobe_skb(*skb, ctx, PWRU_HAS_GET_FUNC_IP, &stackid);
+		return kprobe_skb(*skb, ctx, false, &stackid);
 
 	return BPF_OK;
 }
 
-#undef PWRU_KPROBE
-#undef PWRU_HAS_GET_FUNC_IP
-#undef PWRU_KPROBE_TYPE
-
 SEC("kprobe/skb_lifetime_termination")
 int kprobe_skb_lifetime_termination(struct pt_regs *ctx) {
 	struct sk_buff *skb = (typeof(skb)) PT_REGS_PARM1(ctx);
diff --git a/build.go b/build.go
index 126334d6..baa02454 100644
--- a/build.go
+++ b/build.go
@@ -3,7 +3,6 @@
 
 //go:generate sh -c "echo Generating for $TARGET_GOARCH"
 //go:generate go run github.com/cilium/ebpf/cmd/bpf2go -target $TARGET_GOARCH -cc clang -no-strip KProbePWRU ./bpf/kprobe_pwru.c -- -I./bpf/headers -Wno-address-of-packed-member
-//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -target $TARGET_GOARCH -cc clang -no-strip KProbeMultiPWRU ./bpf/kprobe_pwru.c -- -DHAS_KPROBE_MULTI -I./bpf/headers -Wno-address-of-packed-member
 //go:generate go run github.com/cilium/ebpf/cmd/bpf2go -target $TARGET_GOARCH -cc clang -no-strip FeaturesPWRU ./bpf/features.c -- -I./bpf/headers -Wno-address-of-packed-member
 
 package main
diff --git a/internal/pwru/kprobe.go b/internal/pwru/kprobe.go
index 6d8351b4..8dc926bf 100644
--- a/internal/pwru/kprobe.go
+++ b/internal/pwru/kprobe.go
@@ -201,9 +201,10 @@ func AttachKprobeMulti(ctx context.Context, bar *pb.ProgressBar, kprobes []Kprob
 }
 
 func NewKprober(ctx context.Context, funcs Funcs, coll *ebpf.Collection, a2n Addr2Name, useKprobeMulti bool, batch uint) *kprober {
-	msg := "kprobe"
+	msg, probeMethod := "kprobe", "kprobe"
 	if useKprobeMulti {
 		msg = "kprobe-multi"
+		probeMethod = "kprobe_multi"
 	}
 	log.Printf("Attaching kprobes (via %s)...\n", msg)
 
@@ -213,7 +214,7 @@ func NewKprober(ctx context.Context, funcs Funcs, coll *ebpf.Collection, a2n Add
 	pwruKprobes := make([]Kprobe, 0, len(funcs))
 	funcsByPos := GetFuncsByPos(funcs)
 	for pos, fns := range funcsByPos {
-		fn, ok := coll.Programs[fmt.Sprintf("kprobe_skb_%d", pos)]
+		fn, ok := coll.Programs[fmt.Sprintf("%s_skb_%d", probeMethod, pos)]
 		if ok {
 			pwruKprobes = append(pwruKprobes, Kprobe{HookFuncs: fns, Prog: fn})
 		} else {
diff --git a/main.go b/main.go
index 2aa18137..83c0565f 100644
--- a/main.go
+++ b/main.go
@@ -122,17 +122,21 @@ func main() {
 	opts.Programs.LogLevel = ebpf.LogLevelInstruction
 	opts.Programs.LogSize = ebpf.DefaultVerifierLogSize * 100
 
-	var bpfSpec *ebpf.CollectionSpec
-	switch {
-	case (flags.OutputSkb || flags.OutputShinfo) && useKprobeMulti:
-		bpfSpec, err = LoadKProbeMultiPWRU()
-	default:
-		bpfSpec, err = LoadKProbePWRU()
-	}
+	bpfSpec, err := LoadKProbePWRU()
 	if err != nil {
 		log.Fatalf("Failed to load bpf spec: %v", err)
 	}
 
+	if useKprobeMulti {
+		for i := 1; i <= 5; i++ {
+			delete(bpfSpec.Programs, fmt.Sprintf("kprobe_skb_%d", i))
+		}
+	} else {
+		for i := 1; i <= 5; i++ {
+			delete(bpfSpec.Programs, fmt.Sprintf("kprobe_multi_skb_%d", i))
+		}
+	}
+
 	for name, program := range bpfSpec.Programs {
 		// Skip the skb-tracking ones that should not inject pcap-filter.
 		switch name {