@@ -981,6 +981,8 @@ config NETFILTER_XT_MATCH_CGROUP
tristate '"control group" match support'
depends on NETFILTER_ADVANCED
depends on CGROUPS
+ select NF_SOCK_IPV4
+ select NF_SOCK_IPV6 if IP6_NF_IPTABLES
select CGROUP_NET_CLASSID
---help---
Socket/process control group matching allows you to match locally
@@ -16,6 +16,10 @@
#include <linux/module.h>
#include <linux/netfilter/x_tables.h>
#include <linux/netfilter/xt_cgroup.h>
+#include <linux/netfilter_ipv4.h>
+#include <linux/netfilter_ipv6.h>
+#include <net/netfilter/ipv4/nf_defrag_ipv4.h>
+#include <net/netfilter/ipv6/nf_defrag_ipv6.h>
#include <net/sock.h>
MODULE_LICENSE("GPL");
@@ -34,38 +38,93 @@ static int cgroup_mt_check(const struct xt_mtchk_param *par)
return 0;
}
-static bool
-cgroup_mt(const struct sk_buff *skb, struct xt_action_param *par)
+typedef struct sock *(*cgroup_lookup_t)(const struct sk_buff *skb,
+ const struct net_device *indev);
+
+static bool cgroup_mt(const struct sk_buff *skb,
+ const struct xt_action_param *par,
+ cgroup_lookup_t cgroup_mt_slow)
{
const struct xt_cgroup_info *info = par->matchinfo;
+ struct sock *sk = skb->sk;
+ u32 sk_classid;
+
+ if (sk && sk_fullsock(skb->sk)) {
+ sk_classid = sk->sk_classid;
+ } else {
+ if (par->in)
+ sk = cgroup_mt_slow(skb, par->in);
+
+ if (!sk)
+ return false;
- if (skb->sk == NULL || !sk_fullsock(skb->sk))
- return false;
+ if (!sk_fullsock(sk)) {
+ sock_gen_put(sk);
+ return false;
+ }
+
+ sk_classid = sk->sk_classid;
+ sock_gen_put(sk);
+ }
+
+ return (info->id == sk_classid) ^ info->invert;
+}
- return (info->id == skb->sk->sk_classid) ^ info->invert;
+static bool
+cgroup_mt_v4(const struct sk_buff *skb, struct xt_action_param *par)
+{
+ return cgroup_mt(skb, par, nf_socket_lookup_v4);
+}
+
+#ifdef XT_HAVE_IPV6
+static bool
+cgroup_mt_v6(const struct sk_buff *skb, struct xt_action_param *par)
+{
+ return cgroup_mt(skb, par, nf_socket_lookup_v6);
}
+#endif
-static struct xt_match cgroup_mt_reg __read_mostly = {
- .name = "cgroup",
- .revision = 0,
- .family = NFPROTO_UNSPEC,
- .checkentry = cgroup_mt_check,
- .match = cgroup_mt,
- .matchsize = sizeof(struct xt_cgroup_info),
- .me = THIS_MODULE,
- .hooks = (1 << NF_INET_LOCAL_OUT) |
- (1 << NF_INET_POST_ROUTING) |
- (1 << NF_INET_LOCAL_IN),
+static struct xt_match cgroup_mt_reg[] __read_mostly = {
+ {
+ .name = "cgroup",
+ .revision = 0,
+ .family = NFPROTO_IPV4,
+ .checkentry = cgroup_mt_check,
+ .match = cgroup_mt_v4,
+ .matchsize = sizeof(struct xt_cgroup_info),
+ .me = THIS_MODULE,
+ .hooks = (1 << NF_INET_LOCAL_OUT) |
+ (1 << NF_INET_POST_ROUTING) |
+ (1 << NF_INET_LOCAL_IN),
+ },
+#ifdef XT_HAVE_IPV6
+ {
+ .name = "cgroup",
+ .revision = 0,
+ .family = NFPROTO_IPV6,
+ .checkentry = cgroup_mt_check,
+ .match = cgroup_mt_v6,
+ .matchsize = sizeof(struct xt_cgroup_info),
+ .me = THIS_MODULE,
+ .hooks = (1 << NF_INET_LOCAL_OUT) |
+ (1 << NF_INET_POST_ROUTING) |
+ (1 << NF_INET_LOCAL_IN),
+ }
+#endif
};
static int __init cgroup_mt_init(void)
{
- return xt_register_match(&cgroup_mt_reg);
+ nf_defrag_ipv4_enable();
+#ifdef XT_HAVE_IPV6
+ nf_defrag_ipv6_enable();
+#endif
+ return xt_register_matches(cgroup_mt_reg, ARRAY_SIZE(cgroup_mt_reg));
}
static void __exit cgroup_mt_exit(void)
{
- xt_unregister_match(&cgroup_mt_reg);
+ xt_unregister_matches(cgroup_mt_reg, ARRAY_SIZE(cgroup_mt_reg));
}
module_init(cgroup_mt_init);