@@ -24,16 +24,13 @@ struct cgroup_cls_state
u32 classid;
};
-extern void sock_update_classid(struct sock *sk);
+extern void sock_update_classid(struct sock *sk, struct task_struct *task);
#ifdef CONFIG_NET_CLS_CGROUP
static inline u32 task_cls_classid(struct task_struct *p)
{
int classid;
- if (in_interrupt())
- return 0;
-
rcu_read_lock();
classid = container_of(task_subsys_state(p, net_cls_subsys_id),
struct cgroup_cls_state, css)->classid;
@@ -49,9 +46,6 @@ static inline u32 task_cls_classid(struct task_struct *p)
int id;
u32 classid = 0;
- if (in_interrupt())
- return 0;
-
rcu_read_lock();
id = rcu_dereference_index_check(net_cls_subsys_id,
rcu_read_lock_held());
@@ -64,7 +58,7 @@ static inline u32 task_cls_classid(struct task_struct *p)
}
#endif
#else
-static inline void sock_update_classid(struct sock *sk)
+static inline void sock_update_classid(struct sock *sk, struct task_struct *task)
{
}
@@ -274,7 +274,7 @@ out_free_devname:
return ret;
}
-void net_prio_attach(struct cgroup *cgrp, struct cgroup_taskset *tset)
+static void net_prio_attach(struct cgroup *cgrp, struct cgroup_taskset *tset)
{
struct task_struct *p;
char *tmp = kzalloc(sizeof(char) * PATH_MAX, GFP_KERNEL);
@@ -1223,13 +1223,14 @@ static void sk_prot_free(struct proto *prot, struct sock *sk)
}
#ifdef CONFIG_CGROUPS
-void sock_update_classid(struct sock *sk)
+void sock_update_classid(struct sock *sk, struct task_struct *task)
{
u32 classid;
- rcu_read_lock(); /* doing current task, which cannot vanish. */
- classid = task_cls_classid(current);
- rcu_read_unlock();
+ if (in_interrupt())
+ return;
+
+ classid = task_cls_classid(task);
if (classid && classid != sk->sk_classid)
sk->sk_classid = classid;
}
@@ -1269,7 +1270,7 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority,
sock_net_set(sk, get_net(net));
atomic_set(&sk->sk_wmem_alloc, 1);
- sock_update_classid(sk);
+ sock_update_classid(sk, current);
sock_update_netprioidx(sk, current);
}
@@ -17,6 +17,7 @@
#include <linux/skbuff.h>
#include <linux/cgroup.h>
#include <linux/rcupdate.h>
+#include <linux/fdtable.h>
#include <net/rtnetlink.h>
#include <net/pkt_cls.h>
#include <net/sock.h>
@@ -53,6 +54,42 @@ static void cgrp_destroy(struct cgroup *cgrp)
kfree(cgrp_cls_state(cgrp));
}
+static void cgrp_attach(struct cgroup *cgrp, struct cgroup_taskset *tset)
+{
+ struct task_struct *p;
+
+ cgroup_taskset_for_each(p, cgrp, tset) {
+ unsigned int fd;
+ struct fdtable *fdt;
+ struct files_struct *files;
+
+ task_lock(p);
+ files = p->files;
+ if (!files) {
+ task_unlock(p);
+ continue;
+ }
+
+ spin_lock(&files->file_lock);
+ fdt = files_fdtable(files);
+ for (fd = 0; fd < fdt->max_fds; fd++) {
+ struct file *file;
+ struct socket *sock;
+ int err;
+
+ file = fcheck_files(files, fd);
+ if (!file)
+ continue;
+
+ sock = sock_from_file(file, &err);
+ if (sock)
+ sock_update_netprioidx(sock->sk, p);
+ }
+ spin_unlock(&files->file_lock);
+ task_unlock(p);
+ }
+}
+
static u64 read_classid(struct cgroup *cgrp, struct cftype *cft)
{
return cgrp_cls_state(cgrp)->classid;
@@ -77,6 +114,7 @@ struct cgroup_subsys net_cls_subsys = {
.name = "net_cls",
.create = cgrp_create,
.destroy = cgrp_destroy,
+ .attach = cgrp_attach,
#ifdef CONFIG_NET_CLS_CGROUP
.subsys_id = net_cls_subsys_id,
#endif
@@ -553,8 +553,6 @@ static inline int __sock_sendmsg_nosec(struct kiocb *iocb, struct socket *sock,
{
struct sock_iocb *si = kiocb_to_siocb(iocb);
- sock_update_classid(sock->sk);
-
si->sock = sock;
si->scm = NULL;
si->msg = msg;
@@ -717,8 +715,6 @@ static inline int __sock_recvmsg_nosec(struct kiocb *iocb, struct socket *sock,
{
struct sock_iocb *si = kiocb_to_siocb(iocb);
- sock_update_classid(sock->sk);
-
si->sock = sock;
si->scm = NULL;
si->msg = msg;
@@ -829,8 +825,6 @@ static ssize_t sock_splice_read(struct file *file, loff_t *ppos,
if (unlikely(!sock->ops->splice_read))
return -EINVAL;
- sock_update_classid(sock->sk);
-
return sock->ops->splice_read(sock, ppos, pipe, len, flags);
}
@@ -3353,8 +3347,6 @@ EXPORT_SYMBOL(kernel_setsockopt);
int kernel_sendpage(struct socket *sock, struct page *page, int offset,
size_t size, int flags)
{
- sock_update_classid(sock->sk);
-
if (sock->ops->sendpage)
return sock->ops->sendpage(sock, page, offset, size, flags);