diff mbox series

[16/22] hostapd: MLO: Enhance wpa state machine

Message ID 20240328181652.2956122-17-quic_adisi@quicinc.com
State Accepted
Headers show
Series [01/22] hostapd: MLO: fix for_each_mld_link macro | expand

Commit Message

Aditya Kumar Singh March 28, 2024, 6:16 p.m. UTC
From: Rameshkumar Sundaram <quic_ramess@quicinc.com>

Add required ML Specific members in wpa_authenticator and struct
wpa_state_machine to maintain self and partner link information.

Maintain state machine object in all associated link stations and
destroy/remove references from the same whenever link stations are getting
removed.

Increase the wpa_group object reference count for all links in which ML
station is getting associated and release the same whenever link stations
are getting removed.

Signed-off-by: Rameshkumar Sundaram <quic_ramess@quicinc.com>
Signed-off-by: Aditya Kumar Singh <quic_adisi@quicinc.com>
---
 src/ap/ieee802_11.c |   9 ++--
 src/ap/sta_info.c   |  35 ++++++++++++++-
 src/ap/wpa_auth.c   | 101 +++++++++++++++++++++++++++++++++++++++++---
 src/ap/wpa_auth.h   |  16 +++++++
 src/ap/wpa_auth_i.h |   8 ++++
 5 files changed, 159 insertions(+), 10 deletions(-)
diff mbox series

Patch

diff --git a/src/ap/ieee802_11.c b/src/ap/ieee802_11.c
index 39c63f29bba7..9d04bdf43919 100644
--- a/src/ap/ieee802_11.c
+++ b/src/ap/ieee802_11.c
@@ -4467,6 +4467,8 @@  static int ieee80211_ml_process_link(struct hostapd_data *hapd,
 	}
 
 	sta->flags |= origin_sta->flags | WLAN_STA_ASSOC_REQ_OK;
+	sta->mld_assoc_link_id = origin_sta->mld_assoc_link_id;
+
 	status = __check_assoc_ies(hapd, sta, NULL, 0, &elems, reassoc, true);
 	if (status != WLAN_STATUS_SUCCESS) {
 		wpa_printf(MSG_DEBUG, "MLD: link: Element check failed");
@@ -4474,7 +4476,6 @@  static int ieee80211_ml_process_link(struct hostapd_data *hapd,
 	}
 
 	ap_sta_set_mld(sta, true);
-	sta->mld_assoc_link_id = origin_sta->mld_assoc_link_id;
 
 	os_memcpy(&sta->mld_info, &origin_sta->mld_info, sizeof(sta->mld_info));
 	for (i = 0; i < MAX_NUM_MLD_LINKS; i++) {
@@ -4501,9 +4502,11 @@  static int ieee80211_ml_process_link(struct hostapd_data *hapd,
 			ieee802_11_update_beacons(hapd->iface);
 	}
 
-	/* RSN Authenticator should always be the one on the original station */
+	/* Maintain state machine reference on all link STAs, this is needed
+	 * during Group rekey handling.
+	 */
 	wpa_auth_sta_deinit(sta->wpa_sm);
-	sta->wpa_sm = NULL;
+	sta->wpa_sm = origin_sta->wpa_sm;
 
 	/*
 	 * Do not initialize the EAPOL state machine.
diff --git a/src/ap/sta_info.c b/src/ap/sta_info.c
index 2423ff1896e4..d483aa9d3abf 100644
--- a/src/ap/sta_info.c
+++ b/src/ap/sta_info.c
@@ -199,6 +199,26 @@  static void __ap_free_sta(struct hostapd_data *hapd, struct sta_info *sta)
 	hostapd_drv_sta_remove(hapd, sta->addr);
 }
 
+#ifdef CONFIG_IEEE80211BE
+static void set_for_each_partner_link_sta(struct hostapd_data *hapd,
+					  struct sta_info *psta, void *data)
+{
+	struct sta_info *lsta;
+	struct hostapd_data *lhapd;
+
+	if (!ap_sta_is_mld(hapd, psta))
+		return;
+
+	for_each_mld_link(lhapd, hapd) {
+		if (lhapd == hapd)
+			continue;
+
+		lsta = ap_get_sta(lhapd, psta->addr);
+		if (lsta)
+			lsta->wpa_sm = data;
+	}
+}
+#endif /* CONFIG_IEEE80211BE */
 
 void ap_free_sta(struct hostapd_data *hapd, struct sta_info *sta)
 {
@@ -317,8 +337,17 @@  void ap_free_sta(struct hostapd_data *hapd, struct sta_info *sta)
 
 #ifdef CONFIG_IEEE80211BE
 	if (!ap_sta_is_mld(hapd, sta) ||
-	    hapd->mld_link_id == sta->mld_assoc_link_id)
+	    hapd->mld_link_id == sta->mld_assoc_link_id) {
 		wpa_auth_sta_deinit(sta->wpa_sm);
+		/* Remove refrences from partner links. */
+		set_for_each_partner_link_sta(hapd, sta, NULL);
+	}
+
+	/* Release group references in case non assoc link STA is removed
+	 * before assoc link STA
+	 */
+	if (hostapd_sta_is_link_sta(hapd, sta))
+		wpa_release_link_auth_ref(sta->wpa_sm, hapd->mld_link_id);
 #else
 	wpa_auth_sta_deinit(sta->wpa_sm);
 #endif /* CONFIG_IEEE80211BE */
@@ -903,8 +932,10 @@  static void ap_sta_disconnect_common(struct hostapd_data *hapd,
 	ieee802_1x_free_station(hapd, sta);
 #ifdef CONFIG_IEEE80211BE
 	if (!hapd->conf->mld_ap ||
-	    hapd->mld_link_id == sta->mld_assoc_link_id)
+	    hapd->mld_link_id == sta->mld_assoc_link_id) {
 		wpa_auth_sta_deinit(sta->wpa_sm);
+		set_for_each_partner_link_sta(hapd, sta, NULL);
+	}
 #else
 	wpa_auth_sta_deinit(sta->wpa_sm);
 #endif /* CONFIG_IEEE80211BE */
diff --git a/src/ap/wpa_auth.c b/src/ap/wpa_auth.c
index 0d15c42099dc..8c1052c25ca7 100644
--- a/src/ap/wpa_auth.c
+++ b/src/ap/wpa_auth.c
@@ -102,6 +102,59 @@  static const u8 * wpa_auth_get_spa(const struct wpa_state_machine *sm)
 	return sm->addr;
 }
 
+#ifdef CONFIG_IEEE80211BE
+void wpa_release_link_auth_ref(struct wpa_state_machine *sm, int release_link_id)
+{
+	int link_id;
+
+	if (!sm || release_link_id >= MAX_NUM_MLD_LINKS)
+		return;
+
+	for_each_sm_auth(sm, link_id)
+		if (link_id == release_link_id) {
+			wpa_group_put(sm->mld_links[link_id].wpa_auth,
+				      sm->mld_links[link_id].wpa_auth->group);
+			sm->mld_links[link_id].wpa_auth = NULL;
+		}
+}
+
+static int wpa_get_link_sta_auth(struct wpa_authenticator *wpa_auth, void *data)
+{
+	struct wpa_get_link_auth_ctx *ctx = data;
+
+	if (os_memcmp(wpa_auth->addr, ctx->addr, ETH_ALEN) != 0)
+		return 0;
+	ctx->wpa_auth = wpa_auth;
+	return 1;
+}
+
+static int wpa_get_primary_wpa_auth_cb(struct wpa_authenticator *wpa_auth, void *data)
+{
+	struct wpa_get_link_auth_ctx *ctx = data;
+
+	if (!wpa_auth->is_ml || os_memcmp(wpa_auth->mld_addr, ctx->addr, ETH_ALEN) != 0 ||
+	    !wpa_auth->primary_auth)
+		return 0;
+
+	ctx->wpa_auth = wpa_auth;
+	return 1;
+}
+
+static struct wpa_authenticator *
+wpa_get_primary_wpa_auth(struct wpa_authenticator *wpa_auth)
+{
+	struct wpa_get_link_auth_ctx ctx;
+
+	if (!wpa_auth || !wpa_auth->is_ml || wpa_auth->primary_auth)
+		return wpa_auth;
+
+	ctx.addr = wpa_auth->mld_addr;
+	ctx.wpa_auth = NULL;
+	wpa_auth_for_each_auth(wpa_auth, wpa_get_primary_wpa_auth_cb, &ctx);
+
+	return ctx.wpa_auth;
+}
+#endif /* CONFIG_IEEE80211BE */
 
 static inline int wpa_auth_mic_failure_report(
 	struct wpa_authenticator *wpa_auth, const u8 *addr)
@@ -798,6 +851,10 @@  void wpa_auth_sta_no_wpa(struct wpa_state_machine *sm)
 
 static void wpa_free_sta_sm(struct wpa_state_machine *sm)
 {
+#ifdef CONFIG_IEEE80211BE
+	int link_id;
+#endif /* CONFIG_IEEE80211BE */
+
 #ifdef CONFIG_P2P
 	if (WPA_GET_BE32(sm->ip_addr)) {
 		wpa_printf(MSG_DEBUG,
@@ -821,6 +878,13 @@  static void wpa_free_sta_sm(struct wpa_state_machine *sm)
 	os_free(sm->last_rx_eapol_key);
 	os_free(sm->wpa_ie);
 	os_free(sm->rsnxe);
+#ifdef CONFIG_IEEE80211BE
+	for_each_sm_auth(sm, link_id) {
+		wpa_group_put(sm->mld_links[link_id].wpa_auth,
+			      sm->mld_links[link_id].wpa_auth->group);
+		sm->mld_links[link_id].wpa_auth = NULL;
+	}
+#endif /* CONFIG_IEEE80211BE */
 	wpa_group_put(sm->wpa_auth, sm->group);
 #ifdef CONFIG_DPP2
 	wpabuf_clear_free(sm->dpp_z);
@@ -831,7 +895,7 @@  static void wpa_free_sta_sm(struct wpa_state_machine *sm)
 
 void wpa_auth_sta_deinit(struct wpa_state_machine *sm)
 {
-	struct wpa_authenticator *wpa_auth;
+	struct wpa_authenticator *wpa_auth, *primary_wpa_auth;
 
 	if (!sm)
 		return;
@@ -840,10 +904,18 @@  void wpa_auth_sta_deinit(struct wpa_state_machine *sm)
 	if (wpa_auth->conf.wpa_strict_rekey && sm->has_GTK) {
 		wpa_auth_logger(wpa_auth, wpa_auth_get_spa(sm), LOGGER_DEBUG,
 				"strict rekeying - force GTK rekey since STA is leaving");
+
+#ifdef CONFIG_IEEE80211BE
+		if (wpa_auth->is_ml && !wpa_auth->primary_auth)
+			primary_wpa_auth = wpa_get_primary_wpa_auth(wpa_auth);
+		else
+#endif /* CONFIG_IEEE80211BE */
+			primary_wpa_auth = wpa_auth;
+
 		if (eloop_deplete_timeout(0, 500000, wpa_rekey_gtk,
-					  wpa_auth, NULL) == -1)
+					  primary_wpa_auth, NULL) == -1)
 			eloop_register_timeout(0, 500000, wpa_rekey_gtk,
-					       wpa_auth, NULL);
+					       primary_wpa_auth, NULL);
 	}
 
 	eloop_cancel_timeout(wpa_send_eapol_timeout, wpa_auth, sm);
@@ -6835,6 +6907,7 @@  void wpa_auth_set_ml_info(struct wpa_state_machine *sm, const u8 *mld_addr,
 	for (i = 0, link_id = 0; link_id < MAX_NUM_MLD_LINKS; link_id++) {
 		struct mld_link_info *link = &info->links[link_id];
 		struct mld_link *sm_link = &sm->mld_links[link_id];
+		struct wpa_get_link_auth_ctx ctx;
 
 		sm_link->valid = link->valid;
 		if (!link->valid)
@@ -6849,10 +6922,28 @@  void wpa_auth_set_ml_info(struct wpa_state_machine *sm, const u8 *mld_addr,
 			   MAC2STR(sm_link->own_addr),
 			   MAC2STR(sm_link->peer_addr));
 
-		if (link_id != mld_assoc_link_id)
+		ml_rsn_info.links[i++].link_id = link_id;
+
+		if (link_id != mld_assoc_link_id) {
 			sm->n_mld_affiliated_links++;
+			ctx.addr = link->local_addr;
+			ctx.wpa_auth = NULL;
+			wpa_auth_for_each_auth(sm->wpa_auth, wpa_get_link_sta_auth, &ctx);
+
+			if (ctx.wpa_auth) {
+				sm_link->wpa_auth = ctx.wpa_auth;
+				wpa_group_get(sm_link->wpa_auth,
+					      sm_link->wpa_auth->group);
+			}
+		} else {
+			sm_link->wpa_auth = sm->wpa_auth;
+		}
 
-		ml_rsn_info.links[i++].link_id = link_id;
+		if (!sm_link->wpa_auth)
+			wpa_printf(MSG_ERROR, "Unable to find authenticator object for "
+				    "ML STA " MACSTR " on link " MACSTR " link id %d",
+				    MAC2STR(sm->own_mld_addr), MAC2STR(sm_link->own_addr),
+				    link_id);
 	}
 
 	ml_rsn_info.n_mld_links = i;
diff --git a/src/ap/wpa_auth.h b/src/ap/wpa_auth.h
index c74862307784..1446872f3c10 100644
--- a/src/ap/wpa_auth.h
+++ b/src/ap/wpa_auth.h
@@ -647,4 +647,20 @@  void wpa_auth_ml_get_key_info(struct wpa_authenticator *a,
 			      struct wpa_auth_ml_link_key_info *info,
 			      bool mgmt_frame_prot, bool beacon_prot);
 
+#ifdef CONFIG_IEEE80211BE
+void wpa_release_link_auth_ref(struct wpa_state_machine *sm,
+			       int release_link_id);
+
+#define for_each_sm_auth(sm, link_id) \
+	for (link_id = 0; link_id < MAX_NUM_MLD_LINKS; link_id++) \
+		if (sm->mld_links[link_id].valid && \
+		    sm->mld_links[link_id].wpa_auth && \
+		    sm->wpa_auth != sm->mld_links[link_id].wpa_auth) \
+
+struct wpa_get_link_auth_ctx {
+	u8 *addr;
+	struct wpa_authenticator *wpa_auth;
+};
+#endif /* CONFIG_IEEE80211BE */
+
 #endif /* WPA_AUTH_H */
diff --git a/src/ap/wpa_auth_i.h b/src/ap/wpa_auth_i.h
index 9ba8304150ce..9ba90749da87 100644
--- a/src/ap/wpa_auth_i.h
+++ b/src/ap/wpa_auth_i.h
@@ -186,6 +186,7 @@  struct wpa_state_machine {
 		size_t rsne_len;
 		const u8 *rsnxe;
 		size_t rsnxe_len;
+		struct wpa_authenticator *wpa_auth;
 	} mld_links[MAX_NUM_MLD_LINKS];
 #endif /* CONFIG_IEEE80211BE */
 };
@@ -262,6 +263,13 @@  struct wpa_authenticator {
 #ifdef CONFIG_P2P
 	struct bitfield *ip_pool;
 #endif /* CONFIG_P2P */
+
+#ifdef CONFIG_IEEE80211BE
+	bool is_ml;
+	u8 mld_addr[ETH_ALEN];
+	u8 link_id;
+	bool primary_auth;
+#endif /* CONFIG_IEEE80211BE */
 };