#define _GNU_SOURCE
#include <stdio.h>
#include <stdlib.h>
#include <stddef.h>
#include <string.h>
#include <unistd.h>
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <net/if.h>
#include <inttypes.h>
#include <linux/bpf.h>
#include <bpf/bpf.h>
#include "bpf_insn.h"
#include <sched.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <linux/if_packet.h>
#include <net/ethernet.h>
#include <arpa/inet.h>
#include <sys/stat.h>
#include <stdbool.h>
#include <stdarg.h>
#include <stdint.h>
#include <sys/mman.h>
#include <pthread.h>
#include <sys/time.h>
#include <sys/resource.h>

/* Receives the kernel verifier's log text; passed to bpf_load_program() in prog_load(). */
char bpf_log_buf[BPF_LOG_BUF_SIZE];

/*
 * Assemble and load a minimal BPF_PROG_TYPE_CGROUP_SOCK program.
 *
 * The instruction stream is the concatenation of a prologue (prog_start,
 * which saves the context pointer r1 into r6) and an epilogue (prog_end,
 * which sets the verdict in r0 and exits). This split leaves room to insert
 * further instructions between the two.
 *
 * @idx, @mark, @prio: currently unused. They are presumably placeholders for
 *                     per-socket settings to be emitted between prologue and
 *                     epilogue (TODO confirm).
 *
 * Returns the value of bpf_load_program(): a program fd (>= 0) on success
 * or a negative value on failure, with errno describing the error. Returns
 * -1 with errno = ENOMEM if the instruction buffer cannot be allocated.
 * The kernel verifier output is left in bpf_log_buf.
 */
static int prog_load(__u32 idx, __u32 mark, __u32 prio)
{
	struct bpf_insn prog_start[] = {
		BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
	};
	struct bpf_insn prog_end[] = {
		/*
		 * r0 = verdict.
		 * NOTE(review): the kernel documents that for cgroup sock hooks
		 * 1 allows and 0 rejects the socket() call (EPERM). Confirm that
		 * rejecting all sockets in the target cgroup is intended.
		 */
		BPF_MOV64_IMM(BPF_REG_0, 0),
		BPF_EXIT_INSN(),
	};
	const size_t n_start = sizeof(prog_start) / sizeof(prog_start[0]);
	const size_t n_end = sizeof(prog_end) / sizeof(prog_end[0]);
	const size_t insns_cnt = n_start + n_end;
	struct bpf_insn *prog;
	int saved_errno;
	int ret;

	(void)idx;
	(void)mark;
	(void)prio;

	prog = calloc(insns_cnt, sizeof(*prog));
	if (!prog) {
		fprintf(stderr, "Failed to allocate memory for instructions\n");
		/* Must be negative: callers treat any value >= 0 as a valid fd. */
		errno = ENOMEM;
		return -1;
	}

	/* Typed pointer arithmetic; arithmetic on void * is not standard C. */
	memcpy(prog, prog_start, sizeof(prog_start));
	memcpy(prog + n_start, prog_end, sizeof(prog_end));

	ret = bpf_load_program(BPF_PROG_TYPE_CGROUP_SOCK, prog, insns_cnt,
			       "GPL", 0, bpf_log_buf, BPF_LOG_BUF_SIZE);

	/* Older C libraries may clobber errno in free(); keep the load error. */
	saved_errno = errno;
	free(prog);
	errno = saved_errno;

	return ret;
}

int main(int argc, char **argv)
{
    char* cgrp_path;
    if (argv[1])
    {
        cgrp_path = argv[1];
    }
    else
    {
        cgrp_path = "/sys/fs/cgroup/unified";
    }

    int cg_fd = open(cgrp_path, O_DIRECTORY | O_RDONLY);
	if (cg_fd < 0)
	{
		printf("Failed to open cgroup path: '%s'\n", strerror(errno));
		return EXIT_FAILURE;
	}
    printf("Opened cgroup path =  %s\n", cgrp_path);

    int prog_fd = prog_load(0, 0, 0);
    if( prog_fd < 0 )
    {
        perror("prog_load failed\n");
        return 1;
    }

    printf("Successfuly loaded a BPF_PROG_TYPE_CGROUP_SOCK program\n");

    int ret = bpf_prog_attach(prog_fd, cg_fd,
				      BPF_CGROUP_INET_SOCK_CREATE, 0);
	if (ret < 0)
    {
		printf("Failed to attach prog to cgroup: '%s'\n",
			    strerror(errno));
		return EXIT_FAILURE;
	}
    printf("Successfuly attached the program to BPF_CGROUP_INET_SOCK_CREATE\n");
    return 0;
}
