Skip to content

Commit

Permalink
xdp-sni: add XDP TLS SNI parsing
Browse files Browse the repository at this point in the history
Initially able to extract firefox and safari clienthello
SNI since the server name extension is the first extension.

can't extract server name extension if the server name
extension is not first extension. now it is fixed by
converting each header length to host order.

1, can't parse SNI if the clienthello is segmented in more
than one tcp segment and SNI happens to be not in first
tcp segments due to large clienthello payload. for example
chrome has large TLS extension list exceeding 1500 bytes MTU,
result in two tcp segments.

2, can't parse SNI if clienthello is IP fragmented

IPS like suricata has ip defrag and tcp reassemble, and SNI
filtering, could workaround 1 and 2.

Signed-off-by: Vincent Li <[email protected]>
  • Loading branch information
vincentmli committed Sep 29, 2024
1 parent 216eaa9 commit b6948fd
Show file tree
Hide file tree
Showing 3 changed files with 334 additions and 0 deletions.
9 changes: 9 additions & 0 deletions xdp-sni/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: (GPL-2.0 OR BSD-2-Clause)

XDP_TARGETS := xdp_sni.bpf
BPF_SKEL_TARGETS := $(XDP_TARGETS)
USER_TARGETS := xdp_sni

LIB_DIR = ../lib

include $(LIB_DIR)/common.mk
235 changes: 235 additions & 0 deletions xdp-sni/xdp_sni.bpf.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
/*
* Copyright (c) 2024, BPFire. All rights reserved.
* Credit to Dylan Reimerink to work out extension for loop.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

#include "vmlinux_local.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>
#include <linux/bpf.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/udp.h>
#include <linux/tcp.h>
#include <linux/in.h>

#define SERVER_NAME_EXTENSION 0
#define MAX_DOMAIN_SIZE 127

struct domain_name {
struct bpf_lpm_trie_key lpm_key;
char server_name[MAX_DOMAIN_SIZE + 1];
};

struct {
__uint(type, BPF_MAP_TYPE_LPM_TRIE);
__type(key, struct domain_name);
__type(value, __u8);
__uint(max_entries, 10000);
__uint(pinning, LIBBPF_PIN_BY_NAME);
__uint(map_flags, BPF_F_NO_PREALLOC);
} sni_denylist SEC(".maps");


struct extension {
__u16 type;
__u16 len;
} __attribute__((packed));

struct sni_extension {
__u16 list_len;
__u8 type;
__u16 len;
} __attribute__((packed));

static __always_inline void reverse_string(char *str, __u8 len) {
for (int i = 0; i < (len - 1) / 2; i++) {
char temp = str[i];
str[i] = str[len - 1 - i];
str[len - 1 - i] = temp;
}
}

SEC("xdp")
int xdp_tls_sni(struct xdp_md *ctx) {
void *data_end = (void *)(long)ctx->data_end;
void *data = (void *)(long)ctx->data;
void *cursor = data;

// Parse Ethernet header
struct ethhdr *eth = cursor;
if (cursor + sizeof(*eth) > data_end) goto end;
cursor += sizeof(*eth);

// Only process IPv4 packets
if (eth->h_proto != bpf_htons(ETH_P_IP)) goto end;

// Parse IP header
struct iphdr *ip = cursor;
if (cursor + sizeof(*ip) > data_end) goto end;
cursor += ip->ihl * 4; // IP header length in 32-bit words

// Only process TCP traffic
if (ip->protocol != IPPROTO_TCP) goto end;

// Parse TCP header
struct tcphdr *tcp = cursor;
if (cursor + sizeof(*tcp) > data_end) goto end;
cursor += tcp->doff * 4; // TCP header length in 32-bit words

// Only process traffic on port 443 (HTTPS)
if (tcp->dest != bpf_htons(443)) goto end;

// Check if there's enough data for the TLS ClientHello
if (data_end < cursor + 5) goto end;

// TLS record header
__u8 record_type = *((__u8 *)cursor);
__u16 tls_version = bpf_ntohs(*(__u16 *)(cursor + 1));
__u16 record_length = bpf_ntohs(*(__u16 *)(cursor + 3));

if (record_type != 0x16 || tls_version < 0x0301) goto end; // Only handshake and TLSv1.0+
cursor += 5;

if (record_length > 1024) goto end ;
// Ensure record length doesn't exceed bounds
if (cursor + record_length > data_end) goto end;

// TLS handshake header
if (cursor + 1 > data_end || *((__u8 *)cursor) != 0x01) goto end; // ClientHello
cursor += 4; // Skip handshake message type and length

// Skip TLS version
if (cursor + 2 > data_end) goto end;
cursor += 2;

// Skip random bytes (32 bytes)
if (cursor + 32 > data_end) goto end;
cursor += 32;

// Skip session ID
if (cursor + 1 > data_end) goto end;
__u8 session_id_len = *((__u8 *)cursor);
cursor += 1;
if (cursor + session_id_len > data_end) goto end;
cursor += session_id_len;

// Skip cipher suites
if (cursor + 2 > data_end) goto end;
__u16 cipher_suites_len = bpf_ntohs(*(__u16 *)cursor);
cursor += 2;
if (cipher_suites_len > 254) goto end;
if (cursor + cipher_suites_len > data_end) goto end;
cursor += cipher_suites_len;

// Skip compression methods
if (cursor + 1 > data_end) goto end;

__u8 compression_methods_len = *((__u8 *)cursor);
cursor += 1;
if (cursor + compression_methods_len > data_end) goto end;
cursor += compression_methods_len;

// check bound before get extension_method_len
if (cursor + 2 > data_end) goto end;

__u16 extension_method_len = *(__u16 *)cursor; //here use bpf_ntohs breaks SNI parsing, why?

if (extension_method_len < 0) goto end;

cursor += sizeof(__u16);

for (int i = 0; i < 32; i++)
{
struct extension *ext;
__u16 ext_len = 0;

if (cursor > extension_method_len + data) goto end;

if (data_end < (cursor + sizeof(*ext))) goto end;

ext = (struct extension *)cursor;
ext_len = bpf_ntohs(ext->len);

cursor += sizeof(*ext);

if (ext->type == SERVER_NAME_EXTENSION)
{
struct domain_name dn = {0};

if (data_end < (cursor + sizeof(struct sni_extension))) goto end;

struct sni_extension *sni = (struct sni_extension *)cursor;

cursor += sizeof(struct sni_extension);

__u16 server_name_len = bpf_ntohs(sni->len);

//avoid invalid write to stack R1 off=0 size=1
if (server_name_len >= sizeof(dn.server_name)) goto end;

for (int sn_idx = 0; sn_idx < server_name_len; sn_idx++)
{
// invalid access to packet, off=11 size=1, R5(id=0,off=11,r=11)
// R5 offset is outside of the packet
if (data_end < cursor + sn_idx + 1) goto end;

if (dn.server_name + sizeof(struct domain_name) < dn.server_name + sn_idx) goto end;


dn.server_name[sn_idx] = ((char *)cursor)[sn_idx];
}

dn.server_name[MAX_DOMAIN_SIZE] = '\0';
dn.lpm_key.prefixlen = server_name_len * 8;

bpf_printk("TLS SNI: %s", dn.server_name);

reverse_string(dn.server_name, server_name_len);

if (bpf_map_lookup_elem(&sni_denylist, &dn)) {
bpf_printk("Domain %s found in denylist, dropping packet\n", dn.server_name);
return XDP_DROP;
}

/*
__u8 value = 1;
if (bpf_map_update_elem(&sni_denylist, &dn, &value, BPF_ANY) < 0) {
bpf_printk("Domain %s not updated in denylist\n", dn.server_name);
} else {
bpf_printk("Domain %s updated in denylist\n", dn.server_name);
}
*/

goto end;
}

if (ext_len > 2048) goto end;

if (data_end < cursor + ext_len) goto end;

cursor += ext_len;
}

end:
return XDP_PASS;

}

char _license[] SEC("license") = "GPL";

90 changes: 90 additions & 0 deletions xdp-sni/xdp_sni.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright (c) 2024, BPFire. All rights reserved.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

#include <bpf/bpf.h>
#include <bpf/libbpf.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <stdlib.h>

#define MAX_DOMAIN_SIZE 127

struct domain_name {
struct bpf_lpm_trie_key lpm_key;
char server_name[MAX_DOMAIN_SIZE + 1];
};

// Function to reverse the entire domain string, including the dot
static void reverse_string(char *str) {
int len = strlen(str);
for (int i = 0; i < len / 2; i++) {
char temp = str[i];
str[i] = str[len - i - 1];
str[len - i - 1] = temp;
}
}

int main(int argc, char *argv[]) {
int map_fd;
struct domain_name dn = {0};
__u8 value = 1;

// Check for proper number of arguments
if (argc != 3) {
fprintf(stderr, "Usage: %s <add|delete> <domain>\n", argv[0]);
return 1;
}

// Reverse the input domain
strncpy(dn.server_name, argv[2], MAX_DOMAIN_SIZE);
dn.server_name[MAX_DOMAIN_SIZE] = '\0'; // Ensure null termination
reverse_string(dn.server_name);

// Set the LPM trie key prefix length
dn.lpm_key.prefixlen = strlen(dn.server_name) * 8;

// Open the BPF map
map_fd = bpf_obj_get("/sys/fs/bpf/xdp-sni/sni_denylist");
if (map_fd < 0) {
fprintf(stderr, "Failed to open map: %s\n", strerror(errno));
return 1;
}

// Add or delete the domain based on the first argument
if (strcmp(argv[1], "add") == 0) {
// Update the map with the reversed domain name
if (bpf_map_update_elem(map_fd, &dn, &value, BPF_ANY) != 0) {
fprintf(stderr, "Failed to add domain to map: %s\n", strerror(errno));
return 1;
}
printf("Domain %s (reversed: %s) added to denylist\n", argv[2], dn.server_name);
} else if (strcmp(argv[1], "delete") == 0) {
// Remove the reversed domain from the map
if (bpf_map_delete_elem(map_fd, &dn) != 0) {
fprintf(stderr, "Failed to remove domain from map: %s\n", strerror(errno));
return 1;
}
printf("Domain %s (reversed: %s) removed from denylist\n", argv[2], dn.server_name);
} else {
fprintf(stderr, "Invalid command: %s. Use 'add' or 'delete'.\n", argv[1]);
return 1;
}

return 0;
}

0 comments on commit b6948fd

Please sign in to comment.