diff --git a/tests/hwsim/test_rsn_override.py b/tests/hwsim/test_rsn_override.py index 6733102e7..62419f2c9 100644 --- a/tests/hwsim/test_rsn_override.py +++ b/tests/hwsim/test_rsn_override.py @@ -122,7 +122,11 @@ def test_rsn_override_mld_mixed(dev, apdev): """AP MLD and RSNE=WPA2-Personal/PMF-disabled override=WPA3-Personal/PMF-required on one link""" run_rsn_override_mld(dev, apdev, True) -def run_rsn_override_mld(dev, apdev, mixed): +def test_rsn_override_mld_only_sta(dev, apdev): + """AP MLD and RSN overriding only on STA""" + run_rsn_override_mld(dev, apdev, False, only_sta=True) + +def run_rsn_override_mld(dev, apdev, mixed, only_sta=False): with HWSimRadio(use_mlo=True) as (hapd_radio, hapd_iface), \ HWSimRadio(use_mlo=True) as (wpas_radio, wpas_iface): @@ -141,7 +145,11 @@ def run_rsn_override_mld(dev, apdev, mixed): params['sae_groups'] = '19 20' params['sae_require_mfp'] = '1' params['sae_pwe'] = '2' - if not mixed: + if only_sta: + params['wpa_key_mgmt'] = 'SAE SAE-EXT-KEY' + params['rsn_pairwise'] = 'CCMP GCMP-256' + params['ieee80211w'] = '2' + elif not mixed: params['rsn_override_key_mgmt'] = 'SAE' params['rsn_override_key_mgmt_2'] = 'SAE-EXT-KEY' params['rsn_override_pairwise'] = 'CCMP' @@ -179,6 +187,8 @@ def run_rsn_override_mld(dev, apdev, mixed): eht_verify_wifi_version(wpas) traffic_test(wpas, hapd0) traffic_test(wpas, hapd1) + if only_sta: + return dev[0].set("rsn_overriding", "0") dev[0].connect(ssid, psk=passphrase, key_mgmt="WPA-PSK", @@ -320,3 +330,19 @@ def test_rsn_override_rsnxe_extensibility(dev, apdev): finally: dev[0].set("sae_pwe", "0") dev[0].set("rsn_overriding", "0") + +def test_rsn_override_sta_only(dev, apdev): + """RSN overriding enabled only on the STA""" + check_sae_capab(dev[0]) + params = hostapd.wpa2_params(ssid="test-sae", + passphrase="12345678") + params['wpa_key_mgmt'] = 'SAE' + hapd = hostapd.add_ap(apdev[0], params) + + dev[0].set("sae_groups", "") + try: + dev[0].set("rsn_overriding", "1") + dev[0].connect("test-sae", psk="12345678", key_mgmt="SAE", + scan_freq="2412") + finally: + dev[0].set("rsn_overriding", "0")