diff --git a/tests/hwsim/test_owe.py b/tests/hwsim/test_owe.py index eb5e8a72b..645e8d4dc 100644 --- a/tests/hwsim/test_owe.py +++ b/tests/hwsim/test_owe.py @@ -34,8 +34,12 @@ def test_owe(dev, apdev): if "[WPA2-OWE-CCMP]" not in bss['flags']: raise Exception("OWE AKM not recognized: " + bss['flags']) - dev[0].connect("owe", key_mgmt="OWE", ieee80211w="2", - scan_freq="2412") + id = dev[0].connect("owe", key_mgmt="OWE", ieee80211w="2", scan_freq="2412") + hapd.wait_sta() + pmk_h = hapd.request("GET_PMK " + dev[0].own_addr()) + pmk_w = dev[0].get_pmk(id) + if pmk_h != pmk_w: + raise Exception("Fetched PMK does not match: hostapd %s, wpa_supplicant %s" % (pmk_h, pmk_w)) hwsim_utils.test_connectivity(dev[0], hapd) val = dev[0].get_status_field("key_mgmt") if val != "OWE": diff --git a/tests/hwsim/test_sae.py b/tests/hwsim/test_sae.py index 52d7cf849..d54f0ec95 100644 --- a/tests/hwsim/test_sae.py +++ b/tests/hwsim/test_sae.py @@ -39,6 +39,7 @@ def test_sae(dev, apdev): dev[0].request("SET sae_groups ") id = dev[0].connect("test-sae", psk="12345678", key_mgmt="SAE", scan_freq="2412") + hapd.wait_sta() if dev[0].get_status_field('sae_group') != '19': raise Exception("Expected default SAE group not used") bss = dev[0].get_bss(apdev[0]['bssid']) @@ -51,6 +52,11 @@ def test_sae(dev, apdev): if "sae_group=19" not in res.splitlines(): raise Exception("hostapd STA output did not specify SAE group") + pmk_h = hapd.request("GET_PMK " + dev[0].own_addr()) + pmk_w = dev[0].get_pmk(id) + if pmk_h != pmk_w: + raise Exception("Fetched PMK does not match: hostapd %s, wpa_supplicant %s" % (pmk_h, pmk_w)) + @remote_compatible def test_sae_password_ecc(dev, apdev): """SAE with number of different passwords (ECC)""" diff --git a/tests/hwsim/wpasupplicant.py b/tests/hwsim/wpasupplicant.py index 47bbb3a1a..7ed5d1363 100644 --- a/tests/hwsim/wpasupplicant.py +++ b/tests/hwsim/wpasupplicant.py @@ -1300,6 +1300,14 @@ class WpaSupplicant: return vals return None + def get_pmk(self, network_id): + bssid = self.get_status_field('bssid') + res = self.request("PMKSA_GET %d" % network_id) + for val in res.splitlines(): + if val.startswith(bssid): + return val.split(' ')[2] + return None + def get_sta(self, addr, info=None, next=False): cmd = "STA-NEXT " if next else "STA " if addr is None: