diff --git a/src/common/sae.c b/src/common/sae.c
index 057e1ce3b..372905db0 100644
--- a/src/common/sae.c
+++ b/src/common/sae.c
@@ -1609,18 +1609,26 @@ static int sae_derive_keys(struct sae_data *sae, const u8 *k)
 	 * octets). */
 	crypto_bignum_to_bin(tmp, val, sizeof(val), sae->tmp->order_len);
 	wpa_hexdump(MSG_DEBUG, "SAE: PMKID", val, SAE_PMKID_LEN);
-	if (!sae->pk &&
-	    sae_kdf_hash(hash_len, keyseed, "SAE KCK and PMK",
+
+#ifdef CONFIG_SAE_PK
+	if (sae->pk) {
+		if (sae_kdf_hash(hash_len, keyseed, "SAE-PK keys",
+				 val, sae->tmp->order_len,
+				 keys, 2 * hash_len + SAE_PMK_LEN) < 0)
+			goto fail;
+	} else {
+		if (sae_kdf_hash(hash_len, keyseed, "SAE KCK and PMK",
+				 val, sae->tmp->order_len,
+				 keys, hash_len + SAE_PMK_LEN) < 0)
+			goto fail;
+	}
+#else /* CONFIG_SAE_PK */
+	if (sae_kdf_hash(hash_len, keyseed, "SAE KCK and PMK",
 			 val, sae->tmp->order_len,
 			 keys, hash_len + SAE_PMK_LEN) < 0)
 		goto fail;
-#ifdef CONFIG_SAE_PK
-	if (sae->pk &&
-	    sae_kdf_hash(hash_len, keyseed, "SAE-PK keys",
-			 val, sae->tmp->order_len,
-			 keys, 2 * hash_len + SAE_PMK_LEN) < 0)
-		goto fail;
-#endif /* CONFIG_SAE_PK */
+#endif /* !CONFIG_SAE_PK */
+
 	forced_memzero(keyseed, sizeof(keyseed));
 	os_memcpy(sae->tmp->kck, keys, hash_len);
 	sae->tmp->kck_len = hash_len;