diff mbox series

[09/22] hostapd: MLO: pass link_id in get_hapd_bssid helper function

Message ID 20240328181652.2956122-10-quic_adisi@quicinc.com
State Accepted
Headers show
Series [01/22] hostapd: MLO: fix for_each_mld_link macro | expand

Commit Message

Aditya Kumar Singh March 28, 2024, 6:16 p.m. UTC
From: Sriram R <quic_srirrama@quicinc.com>

Currently get_hapd_bssid() function matches the given bssid in all bsses
of its own iface. However with MLO, there is requirement to check its
own partner BSS at least.

Make changes to compare its link partners as well and if link id passed
matches with the link id of the partner then return it.

Signed-off-by: Sriram R <quic_srirrama@quicinc.com>
Signed-off-by: Aditya Kumar Singh <quic_adisi@quicinc.com>
---
 src/ap/drv_callbacks.c | 47 +++++++++++++++++++++++++-----------------
 1 file changed, 28 insertions(+), 19 deletions(-)
diff mbox series

Patch

diff --git a/src/ap/drv_callbacks.c b/src/ap/drv_callbacks.c
index 2d32069091a9..adac2d478c2a 100644
--- a/src/ap/drv_callbacks.c
+++ b/src/ap/drv_callbacks.c
@@ -1750,7 +1750,7 @@  switch_link_hapd(struct hostapd_data *hapd, int link_id)
 #define HAPD_BROADCAST ((struct hostapd_data *) -1)
 
 static struct hostapd_data * get_hapd_bssid(struct hostapd_iface *iface,
-					    const u8 *bssid)
+					    const u8 *bssid, int link_id)
 {
 	size_t i;
 
@@ -1761,8 +1761,30 @@  static struct hostapd_data * get_hapd_bssid(struct hostapd_iface *iface,
 		return HAPD_BROADCAST;
 
 	for (i = 0; i < iface->num_bss; i++) {
+#ifdef CONFIG_IEEE80211BE
+		struct hostapd_data *hapd, *p_hapd;
+
+		hapd = iface->bss[i];
+		if (ether_addr_equal(bssid, hapd->own_addr) ||
+		    (hapd->conf->mld_ap &&
+		     ether_addr_equal(bssid, hapd->mld->mld_addr) &&
+		     link_id == hapd->mld_link_id)) {
+			return hapd;
+		} else if (hapd->conf->mld_ap) {
+			for_each_mld_link(p_hapd, hapd) {
+				if (p_hapd == hapd)
+					continue;
+
+				if (ether_addr_equal(bssid, p_hapd->own_addr) ||
+				    (ether_addr_equal(bssid, p_hapd->mld->mld_addr) &&
+				     link_id == p_hapd->mld_link_id))
+					return p_hapd;
+			}
+		}
+#else
 		if (ether_addr_equal(bssid, iface->bss[i]->own_addr))
 			return iface->bss[i];
+#endif /*CONFIG_IEEE80211BE */
 	}
 
 	return NULL;
@@ -1773,7 +1795,7 @@  static void hostapd_rx_from_unknown_sta(struct hostapd_data *hapd,
 					const u8 *bssid, const u8 *addr,
 					int wds)
 {
-	hapd = get_hapd_bssid(hapd->iface, bssid);
+	hapd = get_hapd_bssid(hapd->iface, bssid, -1);
 	if (hapd == NULL || hapd == HAPD_BROADCAST)
 		return;
 
@@ -1813,14 +1835,7 @@  static int hostapd_mgmt_rx(struct hostapd_data *hapd, struct rx_mgmt *rx_mgmt)
 	if (bssid == NULL)
 		return 0;
 
-#ifdef CONFIG_IEEE80211BE
-	if (hapd->conf->mld_ap &&
-	    ether_addr_equal(hapd->mld->mld_addr, bssid))
-		is_mld = true;
-#endif /* CONFIG_IEEE80211BE */
-
-	if (!is_mld)
-		hapd = get_hapd_bssid(iface, bssid);
+	hapd = get_hapd_bssid(iface, bssid, rx_mgmt->link_id);
 
 	if (!hapd) {
 		u16 fc = le_to_host16(hdr->frame_control);
@@ -1872,17 +1887,11 @@  static void hostapd_mgmt_tx_cb(struct hostapd_data *hapd, const u8 *buf,
 	struct ieee80211_hdr *hdr;
 	struct hostapd_data *orig_hapd, *tmp_hapd;
 
-#ifdef CONFIG_IEEE80211BE
-	if (hapd->conf->mld_ap && link_id != -1) {
-		tmp_hapd = hostapd_mld_get_link_bss(hapd, link_id);
-		if (tmp_hapd)
-			hapd = tmp_hapd;
-	}
-#endif /* CONFIG_IEEE80211BE */
 	orig_hapd = hapd;
 
 	hdr = (struct ieee80211_hdr *) buf;
-	tmp_hapd = get_hapd_bssid(hapd->iface, get_hdr_bssid(hdr, len));
+	hapd = switch_link_hapd(hapd, link_id);
+	tmp_hapd = get_hapd_bssid(hapd->iface, get_hdr_bssid(hdr, len), link_id);
 	if (tmp_hapd) {
 		hapd = tmp_hapd;
 #ifdef CONFIG_IEEE80211BE
@@ -1899,7 +1908,7 @@  static void hostapd_mgmt_tx_cb(struct hostapd_data *hapd, const u8 *buf,
 		if (stype != WLAN_FC_STYPE_ACTION || len <= 25 ||
 		    buf[24] != WLAN_ACTION_PUBLIC)
 			return;
-		hapd = get_hapd_bssid(orig_hapd->iface, hdr->addr2);
+		hapd = get_hapd_bssid(orig_hapd->iface, hdr->addr2, link_id);
 		if (!hapd || hapd == HAPD_BROADCAST)
 			return;
 		/*