diff --git a/wlantest/rx_data.c b/wlantest/rx_data.c index fc3dee49c..64573c044 100644 --- a/wlantest/rx_data.c +++ b/wlantest/rx_data.c @@ -466,14 +466,18 @@ static void rx_data_bss_prot(struct wlantest *wt, bss = bss_get(wt, hdr->addr1); if (bss == NULL) return; - sta = sta_get(bss, hdr->addr2); + sta = sta_find_mlo(wt, bss, hdr->addr2); + if (!sta) + sta = sta_get(bss, hdr->addr2); if (sta) sta->counters[WLANTEST_STA_COUNTER_PROT_DATA_TX]++; } else if (fc & WLAN_FC_FROMDS) { bss = bss_get(wt, hdr->addr2); if (bss == NULL) return; - sta = sta_get(bss, hdr->addr1); + sta = sta_find_mlo(wt, bss, hdr->addr1); + if (!sta) + sta = sta_get(bss, hdr->addr1); } else { bss = bss_get(wt, hdr->addr3); if (bss == NULL) diff --git a/wlantest/rx_eapol.c b/wlantest/rx_eapol.c index 56566a6c5..55a06d1ac 100644 --- a/wlantest/rx_eapol.c +++ b/wlantest/rx_eapol.c @@ -303,6 +303,7 @@ static void rx_data_eapol_key_2_of_4(struct wlantest *wt, const u8 *dst, size_t kck_len, mic_len; u16 key_info, key_data_len; struct wpa_eapol_ie_parse ie; + int link_id; wpa_printf(MSG_DEBUG, "EAPOL-Key 2/4 " MACSTR " -> " MACSTR, MAC2STR(src), MAC2STR(dst)); @@ -420,6 +421,19 @@ static void rx_data_eapol_key_2_of_4(struct wlantest *wt, const u8 *dst, sta->rsnie[0] ? 2 + sta->rsnie[1] : 0); } } + + for (link_id = 0; link_id < MAX_NUM_MLO_LINKS; link_id++) { + const u8 *addr; + + if (!ie.mlo_link[link_id]) + continue; + addr = &ie.mlo_link[link_id][RSN_MLO_LINK_KDE_LINK_MAC_INDEX]; + wpa_printf(MSG_DEBUG, + "Learned Link ID %u MAC address " MACSTR + " from EAPOL-Key 2/4", + link_id, MAC2STR(addr)); + os_memcpy(sta->link_addr[link_id], addr, ETH_ALEN); + } } diff --git a/wlantest/rx_mgmt.c b/wlantest/rx_mgmt.c index b6703ef44..a44961561 100644 --- a/wlantest/rx_mgmt.c +++ b/wlantest/rx_mgmt.c @@ -2078,10 +2078,15 @@ static void rx_mgmt_action(struct wlantest *wt, const u8 *data, size_t len, bss = bss_get(wt, mgmt->bssid); if (bss == NULL) return; - if (os_memcmp(mgmt->sa, mgmt->bssid, ETH_ALEN) == 0) - sta = sta_get(bss, mgmt->da); - else - sta = sta_get(bss, mgmt->sa); + if (os_memcmp(mgmt->sa, mgmt->bssid, ETH_ALEN) == 0) { + sta = sta_find_mlo(wt, bss, mgmt->da); + if (!sta) + sta = sta_get(bss, mgmt->da); + } else { + sta = sta_find_mlo(wt, bss, mgmt->sa); + if (!sta) + sta = sta_get(bss, mgmt->sa); + } if (sta == NULL) return; @@ -2381,10 +2386,15 @@ static u8 * mgmt_decrypt(struct wlantest *wt, const u8 *data, size_t len, bss = bss_get(wt, hdr->addr3); if (bss == NULL) return mgmt_decrypt_tk(wt, data, len, dlen); - if (os_memcmp(hdr->addr1, hdr->addr3, ETH_ALEN) == 0) - sta = sta_get(bss, hdr->addr2); - else - sta = sta_get(bss, hdr->addr1); + if (os_memcmp(hdr->addr1, hdr->addr3, ETH_ALEN) == 0) { + sta = sta_find_mlo(wt, bss, hdr->addr2); + if (!sta) + sta = sta_get(bss, hdr->addr2); + } else { + sta = sta_find_mlo(wt, bss, hdr->addr1); + if (!sta) + sta = sta_get(bss, hdr->addr1); + } if (sta == NULL || !sta->ptk_set) { decrypted = mgmt_decrypt_tk(wt, data, len, dlen); if (!decrypted) diff --git a/wlantest/sta.c b/wlantest/sta.c index 02ecb78c3..c390e0022 100644 --- a/wlantest/sta.c +++ b/wlantest/sta.c @@ -28,6 +28,48 @@ struct wlantest_sta * sta_find(struct wlantest_bss *bss, const u8 *addr) } +struct wlantest_sta * sta_find_mlo(struct wlantest *wt, + struct wlantest_bss *bss, const u8 *addr) +{ + struct wlantest_sta *sta; + struct wlantest_bss *obss; + int link_id; + + dl_list_for_each(sta, &bss->sta, struct wlantest_sta, list) { + if (os_memcmp(sta->addr, addr, ETH_ALEN) == 0) + return sta; + } + + if (is_zero_ether_addr(addr)) + return NULL; + + dl_list_for_each(sta, &bss->sta, struct wlantest_sta, list) { + for (link_id = 0; link_id < MAX_NUM_MLO_LINKS; link_id++) { + if (os_memcmp(sta->link_addr[link_id], addr, + ETH_ALEN) == 0) + return sta; + } + } + + dl_list_for_each(obss, &wt->bss, struct wlantest_bss, list) { + if (obss == bss) + continue; + dl_list_for_each(sta, &obss->sta, struct wlantest_sta, list) { + if (os_memcmp(sta->addr, addr, ETH_ALEN) == 0) + return sta; + for (link_id = 0; link_id < MAX_NUM_MLO_LINKS; + link_id++) { + if (os_memcmp(sta->link_addr[link_id], addr, + ETH_ALEN) == 0) + return sta; + } + } + } + + return NULL; +} + + struct wlantest_sta * sta_get(struct wlantest_bss *bss, const u8 *addr) { struct wlantest_sta *sta; diff --git a/wlantest/wlantest.h b/wlantest/wlantest.h index 9731e7e6b..5179d48e6 100644 --- a/wlantest/wlantest.h +++ b/wlantest/wlantest.h @@ -56,6 +56,7 @@ struct wlantest_sta { struct wlantest_bss *bss; u8 addr[ETH_ALEN]; u8 mld_mac_addr[ETH_ALEN]; + u8 link_addr[MAX_NUM_MLO_LINKS][ETH_ALEN]; enum { STATE1 /* not authenticated */, STATE2 /* authenticated */, @@ -297,6 +298,8 @@ void pmk_deinit(struct wlantest_pmk *pmk); void tdls_deinit(struct wlantest_tdls *tdls); struct wlantest_sta * sta_find(struct wlantest_bss *bss, const u8 *addr); +struct wlantest_sta * sta_find_mlo(struct wlantest *wt, + struct wlantest_bss *bss, const u8 *addr); struct wlantest_sta * sta_get(struct wlantest_bss *bss, const u8 *addr); void sta_deinit(struct wlantest_sta *sta); void sta_update_assoc(struct wlantest_sta *sta,