diff --git a/src/rsn_supp/tdls.c b/src/rsn_supp/tdls.c index 10b517778..32d596e4a 100644 --- a/src/rsn_supp/tdls.c +++ b/src/rsn_supp/tdls.c @@ -164,6 +164,14 @@ struct wpa_tdls_peer { }; +static const u8 * wpa_tdls_get_link_bssid(struct wpa_sm *sm, int link_id) +{ + if (link_id >= 0) + return sm->mlo.links[link_id].bssid; + return sm->bssid; +} + + static int wpa_tdls_get_privacy(struct wpa_sm *sm) { /* @@ -756,7 +764,8 @@ static void wpa_tdls_linkid(struct wpa_sm *sm, struct wpa_tdls_peer *peer, { lnkid->ie_type = WLAN_EID_LINK_ID; lnkid->ie_len = 3 * ETH_ALEN; - os_memcpy(lnkid->bssid, sm->bssid, ETH_ALEN); + os_memcpy(lnkid->bssid, wpa_tdls_get_link_bssid(sm, peer->mld_link_id), + ETH_ALEN); if (peer->initiator) { os_memcpy(lnkid->init_sta, sm->own_addr, ETH_ALEN); os_memcpy(lnkid->resp_sta, peer->addr, ETH_ALEN); @@ -1949,6 +1958,7 @@ static int wpa_tdls_process_tpk_m1(struct wpa_sm *sm, const u8 *src_addr, u16 status = WLAN_STATUS_UNSPECIFIED_FAILURE; int tdls_prohibited = sm->tdls_prohibited; int existing_peer = 0; + int link_id = -1; if (len < 3 + 3) return -1; @@ -2024,12 +2034,15 @@ static int wpa_tdls_process_tpk_m1(struct wpa_sm *sm, const u8 *src_addr, wpa_hexdump(MSG_DEBUG, "TDLS: Link ID Received from TPK M1", kde.lnkid, kde.lnkid_len); lnkid = (struct wpa_tdls_lnkid *) kde.lnkid; - if (os_memcmp(sm->bssid, lnkid->bssid, ETH_ALEN) != 0) { - wpa_printf(MSG_INFO, "TDLS: TPK M1 from diff BSS"); + + if (!wpa_tdls_is_lnkid_bss_valid(sm, lnkid, &link_id)) { + wpa_printf(MSG_INFO, "TDLS: TPK M1 from diff BSS " + MACSTR, MAC2STR(lnkid->bssid)); status = WLAN_STATUS_REQUEST_DECLINED; goto error; } + peer->mld_link_id = link_id; wpa_printf(MSG_DEBUG, "TDLS: TPK M1 - TPK initiator " MACSTR, MAC2STR(src_addr)); @@ -2255,7 +2268,11 @@ skip_rsn: peer->lifetime = lifetime; - wpa_tdls_generate_tpk(peer, sm->own_addr, sm->bssid); + if (peer->mld_link_id >= 0) + wpa_printf(MSG_DEBUG, "TDLS: Use link ID %u for TPK derivation", + peer->mld_link_id); + wpa_tdls_generate_tpk(peer, sm->own_addr, + wpa_tdls_get_link_bssid(sm, peer->mld_link_id)); skip_rsn_check: #ifdef CONFIG_TDLS_TESTING @@ -2440,7 +2457,8 @@ static int wpa_tdls_process_tpk_m2(struct wpa_sm *sm, const u8 *src_addr, kde.lnkid, kde.lnkid_len); lnkid = (struct wpa_tdls_lnkid *) kde.lnkid; - if (os_memcmp(sm->bssid, lnkid->bssid, ETH_ALEN) != 0) { + if (os_memcmp(sm->bssid, wpa_tdls_get_link_bssid(sm, peer->mld_link_id), + ETH_ALEN) != 0) { wpa_printf(MSG_INFO, "TDLS: TPK M2 from different BSS"); status = WLAN_STATUS_NOT_IN_SAME_BSS; goto error; @@ -2567,7 +2585,11 @@ static int wpa_tdls_process_tpk_m2(struct wpa_sm *sm, const u8 *src_addr, goto error; } - wpa_tdls_generate_tpk(peer, sm->own_addr, sm->bssid); + if (peer->mld_link_id >= 0) + wpa_printf(MSG_DEBUG, "TDLS: Use link ID %u for TPK derivation", + peer->mld_link_id); + wpa_tdls_generate_tpk(peer, sm->own_addr, + wpa_tdls_get_link_bssid(sm, peer->mld_link_id)); /* Process MIC check to see if TPK M2 is right */ if (wpa_supplicant_verify_tdls_mic(2, peer, (const u8 *) lnkid, @@ -2688,7 +2710,8 @@ static int wpa_tdls_process_tpk_m3(struct wpa_sm *sm, const u8 *src_addr, (u8 *) kde.lnkid, kde.lnkid_len); lnkid = (struct wpa_tdls_lnkid *) kde.lnkid; - if (os_memcmp(sm->bssid, lnkid->bssid, ETH_ALEN) != 0) { + if (os_memcmp(wpa_tdls_get_link_bssid(sm, peer->mld_link_id), + lnkid->bssid, ETH_ALEN) != 0) { wpa_printf(MSG_INFO, "TDLS: TPK M3 from diff BSS"); goto error; }