diff --git a/wlantest/rx_eapol.c b/wlantest/rx_eapol.c index ece870215..7e8d28f25 100644 --- a/wlantest/rx_eapol.c +++ b/wlantest/rx_eapol.c @@ -126,6 +126,13 @@ static int try_pmk(struct wlantest *wt, struct wlantest_bss *bss, struct wlantest_pmk *pmk) { struct wpa_ptk ptk; + const u8 *sa, *aa; + bool mlo; + + mlo = !is_zero_ether_addr(sta->mld_mac_addr) && + !is_zero_ether_addr(bss->mld_mac_addr); + sa = mlo ? sta->mld_mac_addr : sta->addr; + aa = mlo ? bss->mld_mac_addr : bss->bssid; if (wpa_key_mgmt_ft(sta->key_mgmt)) { u8 ptk_name[WPA_PMK_NAME_LEN]; @@ -134,19 +141,19 @@ static int try_pmk(struct wlantest *wt, struct wlantest_bss *bss, if (wpa_derive_pmk_r0(pmk->pmk, pmk->pmk_len, bss->ssid, bss->ssid_len, bss->mdid, bss->r0kh_id, bss->r0kh_id_len, - sta->addr, sta->pmk_r0, sta->pmk_r0_name, + sa, sta->pmk_r0, sta->pmk_r0_name, use_sha384) < 0) return -1; sta->pmk_r0_len = use_sha384 ? PMK_LEN_SUITE_B_192 : PMK_LEN; if (wpa_derive_pmk_r1(sta->pmk_r0, sta->pmk_r0_len, sta->pmk_r0_name, - bss->r1kh_id, sta->addr, + bss->r1kh_id, sa, sta->pmk_r1, sta->pmk_r1_name) < 0) return -1; sta->pmk_r1_len = sta->pmk_r0_len; if (wpa_pmk_r1_to_ptk(sta->pmk_r1, sta->pmk_r1_len, - sta->snonce, sta->anonce, sta->addr, - bss->bssid, sta->pmk_r1_name, + sta->snonce, sta->anonce, sa, + aa, sta->pmk_r1_name, &ptk, ptk_name, sta->key_mgmt, sta->pairwise_cipher, 0) < 0 || check_mic(ptk.kck, ptk.kck_len, sta->key_mgmt, ver, data, @@ -154,7 +161,7 @@ static int try_pmk(struct wlantest *wt, struct wlantest_bss *bss, return -1; } else if (wpa_pmk_to_ptk(pmk->pmk, pmk->pmk_len, "Pairwise key expansion", - bss->bssid, sta->addr, sta->anonce, + aa, sa, sta->anonce, sta->snonce, &ptk, sta->key_mgmt, sta->pairwise_cipher, NULL, 0, 0) < 0 || check_mic(ptk.kck, ptk.kck_len, sta->key_mgmt, ver, data, @@ -162,8 +169,16 @@ static int try_pmk(struct wlantest *wt, struct wlantest_bss *bss, return -1; } - wpa_printf(MSG_INFO, "Derived PTK for STA " MACSTR " BSSID " MACSTR, - MAC2STR(sta->addr), MAC2STR(bss->bssid)); + if (mlo) { + wpa_printf(MSG_INFO, "Derived PTK for STA " MACSTR " (MLD " + MACSTR ") BSSID " MACSTR " (MLD " MACSTR ")", + MAC2STR(sta->addr), MAC2STR(sta->mld_mac_addr), + MAC2STR(bss->bssid), MAC2STR(bss->mld_mac_addr)); + } else { + wpa_printf(MSG_INFO, "Derived PTK for STA " MACSTR + " BSSID " MACSTR, + MAC2STR(sta->addr), MAC2STR(bss->bssid)); + } sta->counters[WLANTEST_STA_COUNTER_PTK_LEARNED]++; if (sta->ptk_set) { /* @@ -199,8 +214,9 @@ static void derive_ptk(struct wlantest *wt, struct wlantest_bss *bss, { struct wlantest_pmk *pmk; - wpa_printf(MSG_DEBUG, "Trying to derive PTK for " MACSTR " (ver %u)", - MAC2STR(sta->addr), ver); + wpa_printf(MSG_DEBUG, "Trying to derive PTK for " MACSTR " (MLD " MACSTR + ") (ver %u)", + MAC2STR(sta->addr), MAC2STR(sta->mld_mac_addr), ver); dl_list_for_each(pmk, &bss->pmk, struct wlantest_pmk, list) { wpa_printf(MSG_DEBUG, "Try per-BSS PMK"); if (try_pmk(wt, bss, sta, ver, data, len, pmk) == 0)