diff --git a/wlantest/rx_eapol.c b/wlantest/rx_eapol.c index 6607aeec4..70ab2cb2a 100644 --- a/wlantest/rx_eapol.c +++ b/wlantest/rx_eapol.c @@ -248,14 +248,7 @@ static int try_pmk(struct wlantest *wt, struct wlantest_bss *bss, sta->tptk_set = 1; return 0; } - add_note(wt, MSG_DEBUG, "Derived new PTK"); - os_memcpy(&sta->ptk, &ptk, sizeof(ptk)); - wpa_hexdump(MSG_DEBUG, "PTK:KCK", sta->ptk.kck, sta->ptk.kck_len); - wpa_hexdump(MSG_DEBUG, "PTK:KEK", sta->ptk.kek, sta->ptk.kek_len); - wpa_hexdump(MSG_DEBUG, "PTK:TK", sta->ptk.tk, sta->ptk.tk_len); - sta->ptk_set = 1; - os_memset(sta->rsc_tods, 0, sizeof(sta->rsc_tods)); - os_memset(sta->rsc_fromds, 0, sizeof(sta->rsc_fromds)); + sta_new_ptk(wt, sta, &ptk); return 0; } diff --git a/wlantest/rx_mgmt.c b/wlantest/rx_mgmt.c index c649d4de5..2a32510f1 100644 --- a/wlantest/rx_mgmt.c +++ b/wlantest/rx_mgmt.c @@ -298,11 +298,7 @@ static void process_ft_auth(struct wlantest *wt, struct wlantest_bss *bss, sta->pairwise_cipher, 0) < 0) return; - add_note(wt, MSG_DEBUG, "Derived new PTK"); - os_memcpy(&sta->ptk, &ptk, sizeof(ptk)); - sta->ptk_set = 1; - os_memset(sta->rsc_tods, 0, sizeof(sta->rsc_tods)); - os_memset(sta->rsc_fromds, 0, sizeof(sta->rsc_fromds)); + sta_new_ptk(wt, sta, &ptk); } @@ -1823,11 +1819,7 @@ static void rx_mgmt_action_ft_response(struct wlantest *wt, 0) < 0) return; - add_note(wt, MSG_DEBUG, "Derived new PTK"); - os_memcpy(&new_sta->ptk, &ptk, sizeof(ptk)); - new_sta->ptk_set = 1; - os_memset(new_sta->rsc_tods, 0, sizeof(new_sta->rsc_tods)); - os_memset(new_sta->rsc_fromds, 0, sizeof(new_sta->rsc_fromds)); + sta_new_ptk(wt, new_sta, &ptk); os_memcpy(new_sta->snonce, parse.fte_snonce, WPA_NONCE_LEN); os_memcpy(new_sta->anonce, parse.fte_anonce, WPA_NONCE_LEN); } diff --git a/wlantest/sta.c b/wlantest/sta.c index c390e0022..571ca1284 100644 --- a/wlantest/sta.c +++ b/wlantest/sta.c @@ -272,3 +272,62 @@ skip_rsn_wpa: sta->rsn_capab & WPA_CAPABILITY_EXT_KEY_ID_FOR_UNICAST ? "ExtKeyID " : ""); } + + +static void sta_copy_ptk(struct wlantest_sta *sta, struct wpa_ptk *ptk) +{ + os_memcpy(&sta->ptk, ptk, sizeof(*ptk)); + sta->ptk_set = 1; + os_memset(sta->rsc_tods, 0, sizeof(sta->rsc_tods)); + os_memset(sta->rsc_fromds, 0, sizeof(sta->rsc_fromds)); +} + + +void sta_new_ptk(struct wlantest *wt, struct wlantest_sta *sta, + struct wpa_ptk *ptk) +{ + struct wlantest_bss *bss; + struct wlantest_sta *osta; + + add_note(wt, MSG_DEBUG, "Derived new PTK"); + sta_copy_ptk(sta, ptk); + wpa_hexdump(MSG_DEBUG, "PTK:KCK", sta->ptk.kck, sta->ptk.kck_len); + wpa_hexdump(MSG_DEBUG, "PTK:KEK", sta->ptk.kek, sta->ptk.kek_len); + wpa_hexdump(MSG_DEBUG, "PTK:TK", sta->ptk.tk, sta->ptk.tk_len); + + dl_list_for_each(bss, &wt->bss, struct wlantest_bss, list) { + dl_list_for_each(osta, &bss->sta, struct wlantest_sta, list) { + bool match = false; + int link_id; + + if (osta == sta) + continue; + if (os_memcmp(sta->addr, osta->addr, ETH_ALEN) == 0) + match = true; + for (link_id = 0; !match && link_id < MAX_NUM_MLO_LINKS; + link_id++) { + if (os_memcmp(osta->link_addr[link_id], + sta->addr, ETH_ALEN) == 0) + match = true; + } + + if (!match) + continue; + wpa_printf(MSG_DEBUG, + "Add PTK to another MLO STA entry " MACSTR + " (MLD " MACSTR " --> " MACSTR ") in BSS " + MACSTR " (MLD " MACSTR " --> " MACSTR ")", + MAC2STR(osta->addr), + MAC2STR(osta->mld_mac_addr), + MAC2STR(sta->mld_mac_addr), + MAC2STR(bss->bssid), + MAC2STR(bss->mld_mac_addr), + MAC2STR(sta->bss->mld_mac_addr)); + sta_copy_ptk(osta, ptk); + os_memcpy(osta->mld_mac_addr, sta->mld_mac_addr, + ETH_ALEN); + os_memcpy(osta->bss->mld_mac_addr, + sta->bss->mld_mac_addr, ETH_ALEN); + } + } +} diff --git a/wlantest/wlantest.h b/wlantest/wlantest.h index 6627701f4..20f411783 100644 --- a/wlantest/wlantest.h +++ b/wlantest/wlantest.h @@ -306,6 +306,8 @@ struct wlantest_sta * sta_get(struct wlantest_bss *bss, const u8 *addr); void sta_deinit(struct wlantest_sta *sta); void sta_update_assoc(struct wlantest_sta *sta, struct ieee802_11_elems *elems); +void sta_new_ptk(struct wlantest *wt, struct wlantest_sta *sta, + struct wpa_ptk *ptk); u8 * ccmp_decrypt(const u8 *tk, const struct ieee80211_hdr *hdr, const u8 *a1, const u8 *a2, const u8 *data, size_t data_len,