diff --git a/wlantest/rx_eapol.c b/wlantest/rx_eapol.c index 9f5c6c3b6..0db2587bd 100644 --- a/wlantest/rx_eapol.c +++ b/wlantest/rx_eapol.c @@ -62,9 +62,10 @@ static int check_mic(const u8 *kck, size_t kck_len, int akmp, int ver, static void rx_data_eapol_key_1_of_4(struct wlantest *wt, const u8 *dst, - const u8 *src, const u8 *data, size_t len) + const u8 *src, const u8 *bssid, + const u8 *data, size_t len) { - struct wlantest_bss *bss; + struct wlantest_bss *bss, *bss_mld; struct wlantest_sta *sta; const struct ieee802_1x_hdr *eapol; const struct wpa_eapol_key *hdr; @@ -75,7 +76,16 @@ static void rx_data_eapol_key_1_of_4(struct wlantest *wt, const u8 *dst, wpa_printf(MSG_DEBUG, "EAPOL-Key 1/4 " MACSTR " -> " MACSTR, MAC2STR(src), MAC2STR(dst)); - bss = bss_get(wt, src); + if (os_memcmp(src, bssid, ETH_ALEN) == 0) { + bss = bss_get(wt, src); + } else { + bss = bss_find(wt, bssid); + bss_mld = bss_find(wt, src); + if (bss_mld) + bss = bss_get(wt, src); + else + bss = bss_get(wt, bssid); + } if (bss == NULL) return; sta = sta_get(bss, dst); @@ -282,9 +292,10 @@ static void elems_from_eapol_ie(struct ieee802_11_elems *elems, static void rx_data_eapol_key_2_of_4(struct wlantest *wt, const u8 *dst, - const u8 *src, const u8 *data, size_t len) + const u8 *src, const u8 *bssid, + const u8 *data, size_t len) { - struct wlantest_bss *bss; + struct wlantest_bss *bss, *bss_mld; struct wlantest_sta *sta; const struct ieee802_1x_hdr *eapol; const struct wpa_eapol_key *hdr; @@ -295,7 +306,16 @@ static void rx_data_eapol_key_2_of_4(struct wlantest *wt, const u8 *dst, wpa_printf(MSG_DEBUG, "EAPOL-Key 2/4 " MACSTR " -> " MACSTR, MAC2STR(src), MAC2STR(dst)); - bss = bss_get(wt, dst); + if (os_memcmp(dst, bssid, ETH_ALEN) == 0) { + bss = bss_get(wt, dst); + } else { + bss = bss_find(wt, bssid); + bss_mld = bss_find(wt, dst); + if (bss_mld) + bss = bss_get(wt, dst); + else + bss = bss_get(wt, bssid); + } if (bss == NULL) return; sta = sta_get(bss, src); @@ -669,9 +689,10 @@ static void learn_kde_keys(struct wlantest *wt, struct wlantest_bss *bss, static void rx_data_eapol_key_3_of_4(struct wlantest *wt, const u8 *dst, - const u8 *src, const u8 *data, size_t len) + const u8 *src, const u8 *bssid, + const u8 *data, size_t len) { - struct wlantest_bss *bss; + struct wlantest_bss *bss, *bss_mld; struct wlantest_sta *sta; const struct ieee802_1x_hdr *eapol; const struct wpa_eapol_key *hdr; @@ -687,7 +708,16 @@ static void rx_data_eapol_key_3_of_4(struct wlantest *wt, const u8 *dst, wpa_printf(MSG_DEBUG, "EAPOL-Key 3/4 " MACSTR " -> " MACSTR, MAC2STR(src), MAC2STR(dst)); - bss = bss_get(wt, src); + if (os_memcmp(src, bssid, ETH_ALEN) == 0) { + bss = bss_get(wt, src); + } else { + bss = bss_find(wt, bssid); + bss_mld = bss_find(wt, src); + if (bss_mld) + bss = bss_get(wt, src); + else + bss = bss_get(wt, bssid); + } if (bss == NULL) return; sta = sta_get(bss, dst); @@ -866,9 +896,10 @@ static void rx_data_eapol_key_3_of_4(struct wlantest *wt, const u8 *dst, static void rx_data_eapol_key_4_of_4(struct wlantest *wt, const u8 *dst, - const u8 *src, const u8 *data, size_t len) + const u8 *src, const u8 *bssid, + const u8 *data, size_t len) { - struct wlantest_bss *bss; + struct wlantest_bss *bss, *bss_mld; struct wlantest_sta *sta; const struct ieee802_1x_hdr *eapol; const struct wpa_eapol_key *hdr; @@ -878,7 +909,16 @@ static void rx_data_eapol_key_4_of_4(struct wlantest *wt, const u8 *dst, wpa_printf(MSG_DEBUG, "EAPOL-Key 4/4 " MACSTR " -> " MACSTR, MAC2STR(src), MAC2STR(dst)); - bss = bss_get(wt, dst); + if (os_memcmp(dst, bssid, ETH_ALEN) == 0) { + bss = bss_get(wt, dst); + } else { + bss = bss_find(wt, bssid); + bss_mld = bss_find(wt, dst); + if (bss_mld) + bss = bss_get(wt, dst); + else + bss = bss_get(wt, bssid); + } if (bss == NULL) return; sta = sta_get(bss, src); @@ -925,9 +965,10 @@ static void rx_data_eapol_key_4_of_4(struct wlantest *wt, const u8 *dst, static void rx_data_eapol_key_1_of_2(struct wlantest *wt, const u8 *dst, - const u8 *src, const u8 *data, size_t len) + const u8 *src, const u8 *bssid, + const u8 *data, size_t len) { - struct wlantest_bss *bss; + struct wlantest_bss *bss, *bss_mld; struct wlantest_sta *sta; const struct ieee802_1x_hdr *eapol; const struct wpa_eapol_key *hdr; @@ -938,7 +979,16 @@ static void rx_data_eapol_key_1_of_2(struct wlantest *wt, const u8 *dst, wpa_printf(MSG_DEBUG, "EAPOL-Key 1/2 " MACSTR " -> " MACSTR, MAC2STR(src), MAC2STR(dst)); - bss = bss_get(wt, src); + if (os_memcmp(src, bssid, ETH_ALEN) == 0) { + bss = bss_get(wt, src); + } else { + bss = bss_find(wt, bssid); + bss_mld = bss_find(wt, src); + if (bss_mld) + bss = bss_get(wt, src); + else + bss = bss_get(wt, bssid); + } if (bss == NULL) return; sta = sta_get(bss, dst); @@ -1053,9 +1103,10 @@ static void rx_data_eapol_key_1_of_2(struct wlantest *wt, const u8 *dst, static void rx_data_eapol_key_2_of_2(struct wlantest *wt, const u8 *dst, - const u8 *src, const u8 *data, size_t len) + const u8 *src, const u8 *bssid, + const u8 *data, size_t len) { - struct wlantest_bss *bss; + struct wlantest_bss *bss, *bss_mld; struct wlantest_sta *sta; const struct ieee802_1x_hdr *eapol; const struct wpa_eapol_key *hdr; @@ -1063,7 +1114,16 @@ static void rx_data_eapol_key_2_of_2(struct wlantest *wt, const u8 *dst, wpa_printf(MSG_DEBUG, "EAPOL-Key 2/2 " MACSTR " -> " MACSTR, MAC2STR(src), MAC2STR(dst)); - bss = bss_get(wt, dst); + if (os_memcmp(dst, bssid, ETH_ALEN) == 0) { + bss = bss_get(wt, dst); + } else { + bss = bss_find(wt, bssid); + bss_mld = bss_find(wt, dst); + if (bss_mld) + bss = bss_get(wt, dst); + else + bss = bss_get(wt, bssid); + } if (bss == NULL) return; sta = sta_get(bss, src); @@ -1254,37 +1314,40 @@ static void rx_data_eapol_key(struct wlantest *wt, const u8 *bssid, WPA_KEY_INFO_ACK | WPA_KEY_INFO_INSTALL)) { case WPA_KEY_INFO_ACK: - rx_data_eapol_key_1_of_4(wt, dst, src, data, len); + rx_data_eapol_key_1_of_4(wt, dst, src, bssid, + data, len); break; case WPA_KEY_INFO_MIC: if (key_data_length == 0 || is_zero(hdr->key_nonce, WPA_NONCE_LEN)) - rx_data_eapol_key_4_of_4(wt, dst, src, data, - len); + rx_data_eapol_key_4_of_4(wt, dst, src, bssid, + data, len); else - rx_data_eapol_key_2_of_4(wt, dst, src, data, - len); + rx_data_eapol_key_2_of_4(wt, dst, src, bssid, + data, len); break; case WPA_KEY_INFO_MIC | WPA_KEY_INFO_ACK | WPA_KEY_INFO_INSTALL: /* WPA does not include Secure bit in 3/4 */ - rx_data_eapol_key_3_of_4(wt, dst, src, data, len); + rx_data_eapol_key_3_of_4(wt, dst, src, bssid, + data, len); break; case WPA_KEY_INFO_SECURE | WPA_KEY_INFO_MIC | WPA_KEY_INFO_ACK | WPA_KEY_INFO_INSTALL: case WPA_KEY_INFO_SECURE | WPA_KEY_INFO_ACK | WPA_KEY_INFO_INSTALL: - rx_data_eapol_key_3_of_4(wt, dst, src, data, len); + rx_data_eapol_key_3_of_4(wt, dst, src, bssid, + data, len); break; case WPA_KEY_INFO_SECURE | WPA_KEY_INFO_MIC: case WPA_KEY_INFO_SECURE: if (key_data_length == 0 || is_zero(hdr->key_nonce, WPA_NONCE_LEN)) - rx_data_eapol_key_4_of_4(wt, dst, src, data, - len); + rx_data_eapol_key_4_of_4(wt, dst, src, bssid, + data, len); else - rx_data_eapol_key_2_of_4(wt, dst, src, data, - len); + rx_data_eapol_key_2_of_4(wt, dst, src, bssid, + data, len); break; default: wpa_printf(MSG_DEBUG, "Unsupported EAPOL-Key frame"); @@ -1298,11 +1361,13 @@ static void rx_data_eapol_key(struct wlantest *wt, const u8 *bssid, case WPA_KEY_INFO_SECURE | WPA_KEY_INFO_MIC | WPA_KEY_INFO_ACK: case WPA_KEY_INFO_SECURE | WPA_KEY_INFO_ACK: - rx_data_eapol_key_1_of_2(wt, dst, src, data, len); + rx_data_eapol_key_1_of_2(wt, dst, src, bssid, + data, len); break; case WPA_KEY_INFO_SECURE | WPA_KEY_INFO_MIC: case WPA_KEY_INFO_SECURE: - rx_data_eapol_key_2_of_2(wt, dst, src, data, len); + rx_data_eapol_key_2_of_2(wt, dst, src, bssid, + data, len); break; default: wpa_printf(MSG_DEBUG, "Unsupported EAPOL-Key frame");