#define BPF_NO_GLOBAL_DATA
//#define __TARGET_ARCH_x86
#include "vmlinux.h"
#include <string.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>
#include <bpf/bpf_core_read.h>
#include "common.h"

char LICENSE[] SEC("license") = "Dual BSD/GPL";


/*
 * Raw record layout of the tcp:tcp_receive_reset tracepoint, as passed to a
 * SEC("tracepoint/tcp/tcp_receive_reset") program.
 *
 * Field order, sizes and padding must match the event's format file
 * (/sys/kernel/tracing/events/tcp/tcp_receive_reset/format) on the target
 * kernel. NOTE(review): tracepoint layouts are not a stable ABI — re-check
 * this struct against the format file when targeting another kernel.
 * The handler in this file does not use it (its SEC line is commented out).
 */
struct ctx_receive_reset {
    /* Common header emitted at the start of every trace event record. */
    __u16 common_type; // unsigned short
    __u8 common_flags; // unsigned char
    __u8 common_count; // unsigned char (presumably common_preempt_count in the format file)
    __s32 pid;         // int

    const void *skaddr;    // kernel address of the struct sock
    __u16 sport; 
    __u16 dport;
    __u16 family;          // address family of the socket
    __u8 saddr[4];         // IPv4 source address bytes
    __u8 daddr[4];         // IPv4 destination address bytes
    __u8 saddr_v6[16];     // IPv6 source address bytes
    __u8 daddr_v6[16];     // IPv6 destination address bytes
    __u64 sock_cookie;
};
/*
 * Raw record layout of the tcp:tcp_send_reset tracepoint; this is the ctx
 * type of tcp_retransmit() below.
 *
 * Field order, sizes and padding must match
 * /sys/kernel/tracing/events/tcp/tcp_send_reset/format on the target kernel.
 * NOTE(review): some kernel versions may carry an extra __u16 family field
 * after dport for this event (and newer kernels redefine it), which would
 * shift saddr/daddr — verify against the format file before deploying.
 */
struct ctx_send_reset {
    /* Common header emitted at the start of every trace event record. */
    __u16 common_type; // unsigned short
    __u8 common_flags; // unsigned char
    __u8 common_count; // unsigned char (presumably common_preempt_count in the format file)
    __s32 pid;         // int

    const void *skbaddr;   // kernel address of the sk_buff carrying the RST
    const void *skaddr;    // kernel address of the struct sock (may need a NULL check)
    __s32 state;        // int
    __u16 sport; 
    __u16 dport;
    __u8 saddr[4];         // IPv4 source address bytes
    __u8 daddr[4];         // IPv4 destination address bytes
    __u8 saddr_v6[16];     // IPv6 source address bytes
    __u8 daddr_v6[16];     // IPv6 destination address bytes
};

/*
 * Reset counter. tcp_retransmit() only uses key 0, incrementing it once per
 * reset that passes the family filter.
 * NOTE(review): max_entries is far larger than the single slot used here;
 * presumably kept for user space — confirm before shrinking.
 */
struct {
    __uint(type, BPF_MAP_TYPE_ARRAY);
    __uint(max_entries, 4096);
    /* Spelled out instead of the former typeof(1) trick; same 4-byte int key. */
    __type(key, __s32);
    __type(value, __s32);
} tcp_stats_index SEC(".maps");

/*
 * Reset records (struct reset, defined in common.h) written by
 * tcp_retransmit(). That handler currently writes key 0 only, so each
 * matching event overwrites the previous record.
 */
struct {
    __uint(type, BPF_MAP_TYPE_ARRAY);
    __uint(max_entries, 4096);
    /* Spelled out instead of the former typeof(4) trick; same 4-byte int key. */
    __type(key, __s32);
    __type(value, struct reset);
} tcp_reset_stats SEC(".maps");

struct {
    __uint(type, BPF_MAP_TYPE_ARRAY);
    __uint(max_entries, 1);
    __type(key, __s32);
    __type(value, __u16);
} filter_family SEC(".maps");

struct {
    __uint(type, BPF_MAP_TYPE_ARRAY);
    __uint(max_entries, 1);
    __type(key, __s32);
    __type(value, __u16);
} filter_sport SEC(".maps");

// Capture only packets with the TCP RST flag set: sudo tcpdump -i any 'tcp[13] & 4 != 0' -n

/*
 * This project does not track ports being probed ("sniffed") by scanners, because the
 * tcp:tcp_send_reset tracepoint only fires for an established socket. Still, a large
 * number of TCP RSTs can indicate a problem on your system.
 */

/*
 * List all available tracepoints:
 *   - cat /sys/kernel/tracing/available_events
 * Enable an event:
 *   - echo 'tcp_receive_reset' >> /sys/kernel/tracing/set_event
 *     Use '>>' (append): '>' truncates set_event and disables every other enabled event.
 * Docs: https://docs.kernel.org/trace/events.html
 * https://events.linuxfoundation.org/wp-content/uploads/2022/10/elena-zannoni-tracing-tutorial-LF-2021.pdf
 * https://docs.kernel.org/trace/tracepoints.html
 * Why we want to detect RSTs:
 * When a port is scanned, the scanner sends a SYN and, if the port is closed or blocked,
 * the target answers with a RST, as in this tcpdump capture:
 *   listening on any, link-type LINUX_SLL2 (Linux cooked v2), snapshot length 262144 bytes
 *   10:48:28.531295 lo    In  IP localhost.43961 > localhost.tproxy: Flags [S], seq 2197047013, win 1024, options [mss 1460], length 0
 *   10:48:28.531306 lo    In  IP localhost.tproxy > localhost.43961: Flags [R.], seq 0, ack 2197047014, win 0, length 0
 * We can also make the kernel answer every incoming connection on a port with a RST:
 *   iptables -I INPUT -p tcp --dport <port> -j REJECT --reject-with tcp-reset
 */

/*
 * Alternative attach points used during development (disabled):
 */
//SEC("tp/tcp_retransmit_synack")
//SEC("tracepoint/tcp/tcp_receive_reset")

/*
 * Handler for the tcp:tcp_send_reset tracepoint.
 *
 * Drops the event unless the socket's address family equals
 * filter_family[0]. For matching events it atomically increments
 * tcp_stats_index[0] and stores the IPv4 addresses, ports, family and
 * protocol of the reset in tcp_reset_stats.
 *
 * @ctx: raw tracepoint record (see struct ctx_send_reset).
 * Returns 0 in all cases (the return value is ignored for tracepoints).
 */
SEC("tracepoint/tcp/tcp_send_reset")
int tcp_retransmit(struct ctx_send_reset *ctx)
{
    struct reset s_reset;
    struct sock *sk;
    __u16 *f_family;   /* matches the __u16 value type of filter_family */
    __s32 *index;
    __u16 family;
    __u16 proto = 0;
    __s32 key = 0;

    /* Zero the whole struct, padding included, so no uninitialised stack
     * bytes reach bpf_map_update_elem() (the verifier may reject them). */
    memset(&s_reset, 0, sizeof(s_reset));

    sk = (struct sock *)ctx->skaddr;
    /* NOTE(review): guard against a reset not tied to a socket; without a
     * socket there is no family to filter on. */
    if (!sk)
        return 0;

    f_family = bpf_map_lookup_elem(&filter_family, &key);
    if (!f_family)
        return 0;

    index = bpf_map_lookup_elem(&tcp_stats_index, &key);
    if (!index)
        return 0;

    /* A failed read must not be treated as family 0 (which could match an
     * unset filter), so bail out instead. */
    if (bpf_probe_read_kernel(&family, sizeof(family), &sk->__sk_common.skc_family))
        return 0;
    if (family != *f_family)
        return 0;

    /* The tracepoint can fire on several CPUs at once; a plain
     * "*index += 1" would lose concurrent increments. */
    __sync_fetch_and_add(index, 1);

    /* Best effort: keep proto at 0 if the socket field cannot be read. */
    if (bpf_probe_read_kernel(&proto, sizeof(proto), &sk->sk_protocol))
        proto = 0;

    /* Tracepoint ctx fields are directly readable; only IPv4 addresses are
     * copied into the record. */
    memcpy(s_reset.saddr, ctx->saddr, sizeof(ctx->saddr));
    memcpy(s_reset.daddr, ctx->daddr, sizeof(ctx->daddr));

    s_reset.sport = ctx->sport;
    s_reset.dport = ctx->dport;
    s_reset.family = family;
    s_reset.proto = proto;

    bpf_printk("BPF detected TCP send reset %d %d", s_reset.sport, s_reset.dport);

    /* NOTE(review): every event overwrites key 0 even though *index is
     * incremented above; if a history of up to 4096 resets is intended, the
     * key should be derived from *index — confirm with the user-space reader
     * before changing. */
    bpf_map_update_elem(&tcp_reset_stats, &key, &s_reset, BPF_ANY);
    return 0;
}