diff --git a/src/ap/drv_callbacks.c b/src/ap/drv_callbacks.c index 210068a94..9878a79df 100644 --- a/src/ap/drv_callbacks.c +++ b/src/ap/drv_callbacks.c @@ -1811,14 +1811,15 @@ static int hostapd_event_new_sta(struct hostapd_data *hapd, const u8 *addr) static struct hostapd_data * hostapd_find_by_sta(struct hostapd_iface *iface, - const u8 *src) + const u8 *src, bool rsn) { struct sta_info *sta; unsigned int j; for (j = 0; j < iface->num_bss; j++) { sta = ap_get_sta(iface->bss[j], src); - if (sta && sta->flags & WLAN_STA_ASSOC) + if (sta && (sta->flags & WLAN_STA_ASSOC) && + (!rsn || sta->wpa_sm)) return iface->bss[j]; } @@ -1826,6 +1827,36 @@ static struct hostapd_data * hostapd_find_by_sta(struct hostapd_iface *iface, } +#ifdef CONFIG_IEEE80211BE +static bool search_mld_sta(struct hostapd_data **p_hapd, const u8 *src, + bool rsn) +{ + struct hostapd_data *hapd = *p_hapd; + unsigned int i; + + /* Search for STA on other MLO BSSs */ + for (i = 0; i < hapd->iface->interfaces->count; i++) { + struct hostapd_iface *h = + hapd->iface->interfaces->iface[i]; + struct hostapd_data *h_hapd = h->bss[0]; + struct hostapd_bss_config *hconf = h_hapd->conf; + + if (!hconf->mld_ap || + hconf->mld_id != hapd->conf->mld_id) + continue; + + h_hapd = hostapd_find_by_sta(h, src, true); + if (h_hapd) { + *p_hapd = h_hapd; + return true; + } + } + + return false; +} +#endif /* CONFIG_IEEE80211BE */ + + static void hostapd_event_eapol_rx(struct hostapd_data *hapd, const u8 *src, const u8 *data, size_t data_len, enum frame_encryption encrypted, @@ -1838,36 +1869,28 @@ static void hostapd_event_eapol_rx(struct hostapd_data *hapd, const u8 *src, struct hostapd_data *h_hapd; hapd = switch_link_hapd(hapd, link_id); - h_hapd = hostapd_find_by_sta(hapd->iface, src); + h_hapd = hostapd_find_by_sta(hapd->iface, src, true); if (!h_hapd) - h_hapd = hostapd_find_by_sta(orig_hapd->iface, src); + h_hapd = hostapd_find_by_sta(orig_hapd->iface, src, + true); + if (!h_hapd) + h_hapd = hostapd_find_by_sta(hapd->iface, src, false); + if (!h_hapd) + h_hapd = hostapd_find_by_sta(orig_hapd->iface, src, + false); if (h_hapd) hapd = h_hapd; } else if (hapd->conf->mld_ap) { - unsigned int i; + bool found; - /* Search for STA on other MLO BSSs */ - for (i = 0; i < hapd->iface->interfaces->count; i++) { - struct hostapd_iface *h = - hapd->iface->interfaces->iface[i]; - struct hostapd_data *h_hapd = h->bss[0]; - struct hostapd_bss_config *hconf = h_hapd->conf; - - if (!hconf->mld_ap || - hconf->mld_id != hapd->conf->mld_id) - continue; - - h_hapd = hostapd_find_by_sta(h, src); - if (h_hapd) { - hapd = h_hapd; - break; - } - } + found = search_mld_sta(&hapd, src, true); + if (!found) + search_mld_sta(&hapd, src, false); } else { - hapd = hostapd_find_by_sta(hapd->iface, src); + hapd = hostapd_find_by_sta(hapd->iface, src, false); } #else /* CONFIG_IEEE80211BE */ - hapd = hostapd_find_by_sta(hapd->iface, src); + hapd = hostapd_find_by_sta(hapd->iface, src, false); #endif /* CONFIG_IEEE80211BE */ if (!hapd) {