wlantest: Find a STA entry based on MLO affiliated link addresses

Allow a single STA entry to be found for a non-AP MLD regardless of
which link MAC address was used to transmit/receive it.

Signed-off-by: Jouni Malinen <quic_jouni@quicinc.com>
This commit is contained in:
Jouni Malinen 2022-09-29 13:08:37 +03:00 committed by Jouni Malinen
parent 5d5c2cb2be
commit 228420e2d9
5 changed files with 83 additions and 10 deletions

View file

@ -466,6 +466,8 @@ static void rx_data_bss_prot(struct wlantest *wt,
bss = bss_get(wt, hdr->addr1);
if (bss == NULL)
return;
sta = sta_find_mlo(wt, bss, hdr->addr2);
if (!sta)
sta = sta_get(bss, hdr->addr2);
if (sta)
sta->counters[WLANTEST_STA_COUNTER_PROT_DATA_TX]++;
@ -473,6 +475,8 @@ static void rx_data_bss_prot(struct wlantest *wt,
bss = bss_get(wt, hdr->addr2);
if (bss == NULL)
return;
sta = sta_find_mlo(wt, bss, hdr->addr1);
if (!sta)
sta = sta_get(bss, hdr->addr1);
} else {
bss = bss_get(wt, hdr->addr3);

View file

@ -303,6 +303,7 @@ static void rx_data_eapol_key_2_of_4(struct wlantest *wt, const u8 *dst,
size_t kck_len, mic_len;
u16 key_info, key_data_len;
struct wpa_eapol_ie_parse ie;
int link_id;
wpa_printf(MSG_DEBUG, "EAPOL-Key 2/4 " MACSTR " -> " MACSTR,
MAC2STR(src), MAC2STR(dst));
@ -420,6 +421,19 @@ static void rx_data_eapol_key_2_of_4(struct wlantest *wt, const u8 *dst,
sta->rsnie[0] ? 2 + sta->rsnie[1] : 0);
}
}
for (link_id = 0; link_id < MAX_NUM_MLO_LINKS; link_id++) {
const u8 *addr;
if (!ie.mlo_link[link_id])
continue;
addr = &ie.mlo_link[link_id][RSN_MLO_LINK_KDE_LINK_MAC_INDEX];
wpa_printf(MSG_DEBUG,
"Learned Link ID %u MAC address " MACSTR
" from EAPOL-Key 2/4",
link_id, MAC2STR(addr));
os_memcpy(sta->link_addr[link_id], addr, ETH_ALEN);
}
}

View file

@ -2078,10 +2078,15 @@ static void rx_mgmt_action(struct wlantest *wt, const u8 *data, size_t len,
bss = bss_get(wt, mgmt->bssid);
if (bss == NULL)
return;
if (os_memcmp(mgmt->sa, mgmt->bssid, ETH_ALEN) == 0)
if (os_memcmp(mgmt->sa, mgmt->bssid, ETH_ALEN) == 0) {
sta = sta_find_mlo(wt, bss, mgmt->da);
if (!sta)
sta = sta_get(bss, mgmt->da);
else
} else {
sta = sta_find_mlo(wt, bss, mgmt->sa);
if (!sta)
sta = sta_get(bss, mgmt->sa);
}
if (sta == NULL)
return;
@ -2381,10 +2386,15 @@ static u8 * mgmt_decrypt(struct wlantest *wt, const u8 *data, size_t len,
bss = bss_get(wt, hdr->addr3);
if (bss == NULL)
return mgmt_decrypt_tk(wt, data, len, dlen);
if (os_memcmp(hdr->addr1, hdr->addr3, ETH_ALEN) == 0)
if (os_memcmp(hdr->addr1, hdr->addr3, ETH_ALEN) == 0) {
sta = sta_find_mlo(wt, bss, hdr->addr2);
if (!sta)
sta = sta_get(bss, hdr->addr2);
else
} else {
sta = sta_find_mlo(wt, bss, hdr->addr1);
if (!sta)
sta = sta_get(bss, hdr->addr1);
}
if (sta == NULL || !sta->ptk_set) {
decrypted = mgmt_decrypt_tk(wt, data, len, dlen);
if (!decrypted)

View file

@ -28,6 +28,48 @@ struct wlantest_sta * sta_find(struct wlantest_bss *bss, const u8 *addr)
}
struct wlantest_sta * sta_find_mlo(struct wlantest *wt,
struct wlantest_bss *bss, const u8 *addr)
{
struct wlantest_sta *sta;
struct wlantest_bss *obss;
int link_id;
dl_list_for_each(sta, &bss->sta, struct wlantest_sta, list) {
if (os_memcmp(sta->addr, addr, ETH_ALEN) == 0)
return sta;
}
if (is_zero_ether_addr(addr))
return NULL;
dl_list_for_each(sta, &bss->sta, struct wlantest_sta, list) {
for (link_id = 0; link_id < MAX_NUM_MLO_LINKS; link_id++) {
if (os_memcmp(sta->link_addr[link_id], addr,
ETH_ALEN) == 0)
return sta;
}
}
dl_list_for_each(obss, &wt->bss, struct wlantest_bss, list) {
if (obss == bss)
continue;
dl_list_for_each(sta, &obss->sta, struct wlantest_sta, list) {
if (os_memcmp(sta->addr, addr, ETH_ALEN) == 0)
return sta;
for (link_id = 0; link_id < MAX_NUM_MLO_LINKS;
link_id++) {
if (os_memcmp(sta->link_addr[link_id], addr,
ETH_ALEN) == 0)
return sta;
}
}
}
return NULL;
}
struct wlantest_sta * sta_get(struct wlantest_bss *bss, const u8 *addr)
{
struct wlantest_sta *sta;

View file

@ -56,6 +56,7 @@ struct wlantest_sta {
struct wlantest_bss *bss;
u8 addr[ETH_ALEN];
u8 mld_mac_addr[ETH_ALEN];
u8 link_addr[MAX_NUM_MLO_LINKS][ETH_ALEN];
enum {
STATE1 /* not authenticated */,
STATE2 /* authenticated */,
@ -297,6 +298,8 @@ void pmk_deinit(struct wlantest_pmk *pmk);
void tdls_deinit(struct wlantest_tdls *tdls);
struct wlantest_sta * sta_find(struct wlantest_bss *bss, const u8 *addr);
struct wlantest_sta * sta_find_mlo(struct wlantest *wt,
struct wlantest_bss *bss, const u8 *addr);
struct wlantest_sta * sta_get(struct wlantest_bss *bss, const u8 *addr);
void sta_deinit(struct wlantest_sta *sta);
void sta_update_assoc(struct wlantest_sta *sta,