diff mbox

[V2,2/6] cxl: Keep track of mm struct associated with a context

Message ID 1489489722-935-3-git-send-email-clombard@linux.vnet.ibm.com (mailing list archive)
State Changes Requested
Headers show

Commit Message

Christophe Lombard March 14, 2017, 11:08 a.m. UTC
The mm_struct corresponding to the current task is acquired each time
an interrupt is raised. So to simplify the code, we only get the
mm_struct when attaching an AFU context to the process.
The mm_count reference is increased to ensure that the mm_struct can't
be freed. The mm_struct will be released when the context is detached.
The reference (use count) on the struct mm is not kept to avoid a
circular dependency if the process mmaps its cxl mmio and forget to
unmap before exiting.

Signed-off-by: Christophe Lombard <clombard@linux.vnet.ibm.com>
---
 drivers/misc/cxl/api.c     | 17 ++++++++--
 drivers/misc/cxl/context.c | 26 ++++++++++++++--
 drivers/misc/cxl/cxl.h     | 13 ++++++--
 drivers/misc/cxl/fault.c   | 77 +++++-----------------------------------------
 drivers/misc/cxl/file.c    | 15 +++++++--
 5 files changed, 68 insertions(+), 80 deletions(-)

Comments

Frederic Barrat March 20, 2017, 2:27 p.m. UTC | #1
Le 14/03/2017 à 12:08, Christophe Lombard a écrit :
> The mm_struct corresponding to the current task is acquired each time
> an interrupt is raised. So to simplify the code, we only get the
> mm_struct when attaching an AFU context to the process.
> The mm_count reference is increased to ensure that the mm_struct can't
> be freed. The mm_struct will be released when the context is detached.
> The reference (use count) on the struct mm is not kept to avoid a
> circular dependency if the process mmaps its cxl mmio and forget to
> unmap before exiting.

I think it should be:
A reference on mm_users is not kept to avoid...


>
> Signed-off-by: Christophe Lombard <clombard@linux.vnet.ibm.com>
> ---
>  drivers/misc/cxl/api.c     | 17 ++++++++--
>  drivers/misc/cxl/context.c | 26 ++++++++++++++--
>  drivers/misc/cxl/cxl.h     | 13 ++++++--
>  drivers/misc/cxl/fault.c   | 77 +++++-----------------------------------------
>  drivers/misc/cxl/file.c    | 15 +++++++--
>  5 files changed, 68 insertions(+), 80 deletions(-)
>
> diff --git a/drivers/misc/cxl/api.c b/drivers/misc/cxl/api.c
> index bcc030e..1a138c8 100644
> --- a/drivers/misc/cxl/api.c
> +++ b/drivers/misc/cxl/api.c
> @@ -14,6 +14,7 @@
>  #include <linux/msi.h>
>  #include <linux/module.h>
>  #include <linux/mount.h>
> +#include <linux/sched/mm.h>
>
>  #include "cxl.h"
>
> @@ -321,19 +322,29 @@ int cxl_start_context(struct cxl_context *ctx, u64 wed,
>
>  	if (task) {
>  		ctx->pid = get_task_pid(task, PIDTYPE_PID);
> -		ctx->glpid = get_task_pid(task->group_leader, PIDTYPE_PID);
>  		kernel = false;
>  		ctx->real_mode = false;
> +
> +		/* acquire a reference to the task's mm */
> +		ctx->mm = get_task_mm(current);
> +
> +		/* ensure this mm_struct can't be freed */
> +		cxl_context_mm_count_get(ctx);
> +
> +		/* decrement the use count */
> +		if (ctx->mm)
> +			mmput(ctx->mm);
>  	}
>
>  	cxl_ctx_get();
>
>  	if ((rc = cxl_ops->attach_process(ctx, kernel, wed, 0))) {
> -		put_pid(ctx->glpid);
>  		put_pid(ctx->pid);
> -		ctx->glpid = ctx->pid = NULL;
> +		ctx->pid = NULL;
>  		cxl_adapter_context_put(ctx->afu->adapter);
>  		cxl_ctx_put();
> +		if (task)
> +			cxl_context_mm_count_put(ctx);
>  		goto out;
>  	}
>
> diff --git a/drivers/misc/cxl/context.c b/drivers/misc/cxl/context.c
> index 062bf6c..ed0a447 100644
> --- a/drivers/misc/cxl/context.c
> +++ b/drivers/misc/cxl/context.c
> @@ -17,6 +17,7 @@
>  #include <linux/debugfs.h>
>  #include <linux/slab.h>
>  #include <linux/idr.h>
> +#include <linux/sched/mm.h>
>  #include <asm/cputable.h>
>  #include <asm/current.h>
>  #include <asm/copro.h>
> @@ -41,7 +42,7 @@ int cxl_context_init(struct cxl_context *ctx, struct cxl_afu *afu, bool master)
>  	spin_lock_init(&ctx->sste_lock);
>  	ctx->afu = afu;
>  	ctx->master = master;
> -	ctx->pid = ctx->glpid = NULL; /* Set in start work ioctl */
> +	ctx->pid = NULL; /* Set in start work ioctl */
>  	mutex_init(&ctx->mapping_lock);
>  	ctx->mapping = NULL;
>
> @@ -242,12 +243,15 @@ int __detach_context(struct cxl_context *ctx)
>
>  	/* release the reference to the group leader and mm handling pid */
>  	put_pid(ctx->pid);
> -	put_pid(ctx->glpid);
>
>  	cxl_ctx_put();
>
>  	/* Decrease the attached context count on the adapter */
>  	cxl_adapter_context_put(ctx->afu->adapter);
> +
> +	/* Decrease the mm count on the context */
> +	cxl_context_mm_count_put(ctx);
> +
>  	return 0;
>  }
>
> @@ -325,3 +329,21 @@ void cxl_context_free(struct cxl_context *ctx)
>  	mutex_unlock(&ctx->afu->contexts_lock);
>  	call_rcu(&ctx->rcu, reclaim_ctx);
>  }
> +
> +void cxl_context_mm_count_get(struct cxl_context *ctx)
> +{
> +	if (ctx->mm)
> +		atomic_inc(&ctx->mm->mm_count);
> +}
> +
> +void cxl_context_mm_count_put(struct cxl_context *ctx)
> +{
> +	if (ctx->mm)
> +		mmdrop(ctx->mm);
> +}
> +
> +void cxl_context_mm_users_get(struct cxl_context *ctx)
> +{
> +	if (ctx->mm)
> +		atomic_inc(&ctx->mm->mm_users);
> +}
> diff --git a/drivers/misc/cxl/cxl.h b/drivers/misc/cxl/cxl.h
> index 79e60ec..4d1b704 100644
> --- a/drivers/misc/cxl/cxl.h
> +++ b/drivers/misc/cxl/cxl.h
> @@ -482,8 +482,6 @@ struct cxl_context {
>  	unsigned int sst_size, sst_lru;
>
>  	wait_queue_head_t wq;
> -	/* pid of the group leader associated with the pid */
> -	struct pid *glpid;
>  	/* use mm context associated with this pid for ds faults */
>  	struct pid *pid;
>  	spinlock_t lock; /* Protects pending_irq_mask, pending_fault and fault_addr */
> @@ -551,6 +549,8 @@ struct cxl_context {
>  	 * CX4 only:
>  	 */
>  	struct list_head extra_irq_contexts;
> +
> +	struct mm_struct *mm;
>  };
>
>  struct cxl_service_layer_ops {
> @@ -1024,4 +1024,13 @@ int cxl_adapter_context_lock(struct cxl *adapter);
>  /* Unlock the contexts-lock if taken. Warn and force unlock otherwise */
>  void cxl_adapter_context_unlock(struct cxl *adapter);
>
> +/* Increases the reference count to "struct mm_struct" */
> +void cxl_context_mm_count_get(struct cxl_context *ctx);
> +
> +/* Decrements the reference count to "struct mm_struct" */
> +void cxl_context_mm_count_put(struct cxl_context *ctx);
> +
> +/* Increases the reference users to "struct mm_struct" */
> +void cxl_context_mm_users_get(struct cxl_context *ctx);
> +
>  #endif
> diff --git a/drivers/misc/cxl/fault.c b/drivers/misc/cxl/fault.c
> index 2fa015c..14a5bfa 100644
> --- a/drivers/misc/cxl/fault.c
> +++ b/drivers/misc/cxl/fault.c
> @@ -170,81 +170,19 @@ static void cxl_handle_page_fault(struct cxl_context *ctx,
>  }
>
>  /*
> - * Returns the mm_struct corresponding to the context ctx via ctx->pid
> - * In case the task has exited we use the task group leader accessible
> - * via ctx->glpid to find the next task in the thread group that has a
> - * valid  mm_struct associated with it. If a task with valid mm_struct
> - * is found the ctx->pid is updated to use the task struct for subsequent
> - * translations. In case no valid mm_struct is found in the task group to
> - * service the fault a NULL is returned.
> + * Returns the mm_struct corresponding to the context ctx.
> + * mm_users == 0, the context may be in the process of being closed.
>   */
>  static struct mm_struct *get_mem_context(struct cxl_context *ctx)
>  {
> -	struct task_struct *task = NULL;
> -	struct mm_struct *mm = NULL;
> -	struct pid *old_pid = ctx->pid;
> -
> -	if (old_pid == NULL) {
> -		pr_warn("%s: Invalid context for pe=%d\n",
> -			 __func__, ctx->pe);
> +	if (ctx->mm == NULL)
>  		return NULL;
> -	}
> -
> -	task = get_pid_task(old_pid, PIDTYPE_PID);
> -
> -	/*
> -	 * pid_alive may look racy but this saves us from costly
> -	 * get_task_mm when the task is a zombie. In worst case
> -	 * we may think a task is alive, which is about to die
> -	 * but get_task_mm will return NULL.
> -	 */
> -	if (task != NULL && pid_alive(task))
> -		mm = get_task_mm(task);
>
> -	/* release the task struct that was taken earlier */
> -	if (task)
> -		put_task_struct(task);
> -	else
> -		pr_devel("%s: Context owning pid=%i for pe=%i dead\n",
> -			__func__, pid_nr(old_pid), ctx->pe);
> -
> -	/*
> -	 * If we couldn't find the mm context then use the group
> -	 * leader to iterate over the task group and find a task
> -	 * that gives us mm_struct.
> -	 */
> -	if (unlikely(mm == NULL && ctx->glpid != NULL)) {
> -
> -		rcu_read_lock();
> -		task = pid_task(ctx->glpid, PIDTYPE_PID);
> -		if (task)
> -			do {
> -				mm = get_task_mm(task);
> -				if (mm) {
> -					ctx->pid = get_task_pid(task,
> -								PIDTYPE_PID);
> -					break;
> -				}
> -				task = next_thread(task);
> -			} while (task && !thread_group_leader(task));
> -		rcu_read_unlock();
> -
> -		/* check if we switched pid */
> -		if (ctx->pid != old_pid) {
> -			if (mm)
> -				pr_devel("%s:pe=%i switch pid %i->%i\n",
> -					 __func__, ctx->pe, pid_nr(old_pid),
> -					 pid_nr(ctx->pid));
> -			else
> -				pr_devel("%s:Cannot find mm for pid=%i\n",
> -					 __func__, pid_nr(old_pid));
> -
> -			/* drop the reference to older pid */
> -			put_pid(old_pid);
> -		}
> -	}
> +	if (atomic_read(&ctx->mm->mm_users) == 0)
> +		return NULL;
>
> -	return mm;
> +	cxl_context_mm_users_get(ctx);
> +	return ctx->mm;

It should be done atomically:

	if (!atomic_inc_not_zero(&ctx->mm->mm_users))
		return ctx->mm;

	return NULL;

in which case, we don't need cxl_context_mm_users_get() any more.


   Fred


>  }
>
>
> @@ -282,7 +220,6 @@ void cxl_handle_fault(struct work_struct *fault_work)
>  	if (!ctx->kernel) {
>
>  		mm = get_mem_context(ctx);
> -		/* indicates all the thread in task group have exited */
>  		if (mm == NULL) {
>  			pr_devel("%s: unable to get mm for pe=%d pid=%i\n",
>  				 __func__, ctx->pe, pid_nr(ctx->pid));
> diff --git a/drivers/misc/cxl/file.c b/drivers/misc/cxl/file.c
> index e7139c7..17b433f 100644
> --- a/drivers/misc/cxl/file.c
> +++ b/drivers/misc/cxl/file.c
> @@ -18,6 +18,7 @@
>  #include <linux/fs.h>
>  #include <linux/mm.h>
>  #include <linux/slab.h>
> +#include <linux/sched/mm.h>
>  #include <asm/cputable.h>
>  #include <asm/current.h>
>  #include <asm/copro.h>
> @@ -216,8 +217,16 @@ static long afu_ioctl_start_work(struct cxl_context *ctx,
>  	 * process is still accessible.
>  	 */
>  	ctx->pid = get_task_pid(current, PIDTYPE_PID);
> -	ctx->glpid = get_task_pid(current->group_leader, PIDTYPE_PID);
>
> +	/* acquire a reference to the task's mm */
> +	ctx->mm = get_task_mm(current);
> +
> +	/* ensure this mm_struct can't be freed */
> +	cxl_context_mm_count_get(ctx);
> +
> +	/* decrement the use count */
> +	if (ctx->mm)
> +		mmput(ctx->mm);
>
>  	trace_cxl_attach(ctx, work.work_element_descriptor, work.num_interrupts, amr);
>
> @@ -225,9 +234,9 @@ static long afu_ioctl_start_work(struct cxl_context *ctx,
>  							amr))) {
>  		afu_release_irqs(ctx, ctx);
>  		cxl_adapter_context_put(ctx->afu->adapter);
> -		put_pid(ctx->glpid);
>  		put_pid(ctx->pid);
> -		ctx->glpid = ctx->pid = NULL;
> +		ctx->pid = NULL;
> +		cxl_context_mm_count_put(ctx);
>  		goto out;
>  	}
>
Frederic Barrat March 21, 2017, 11:28 a.m. UTC | #2
Another thought about that patch. Now that we keep track of the mm 
associated to a context, I think we can simplify slightly the function 
_cxl_slbia() in main.c, where we look for the mm based on the pid. We 
now have the information readily available.

   Fred


Le 14/03/2017 à 12:08, Christophe Lombard a écrit :
> The mm_struct corresponding to the current task is acquired each time
> an interrupt is raised. So to simplify the code, we only get the
> mm_struct when attaching an AFU context to the process.
> The mm_count reference is increased to ensure that the mm_struct can't
> be freed. The mm_struct will be released when the context is detached.
> The reference (use count) on the struct mm is not kept to avoid a
> circular dependency if the process mmaps its cxl mmio and forget to
> unmap before exiting.
>
> Signed-off-by: Christophe Lombard <clombard@linux.vnet.ibm.com>
> ---
>  drivers/misc/cxl/api.c     | 17 ++++++++--
>  drivers/misc/cxl/context.c | 26 ++++++++++++++--
>  drivers/misc/cxl/cxl.h     | 13 ++++++--
>  drivers/misc/cxl/fault.c   | 77 +++++-----------------------------------------
>  drivers/misc/cxl/file.c    | 15 +++++++--
>  5 files changed, 68 insertions(+), 80 deletions(-)
>
> diff --git a/drivers/misc/cxl/api.c b/drivers/misc/cxl/api.c
> index bcc030e..1a138c8 100644
> --- a/drivers/misc/cxl/api.c
> +++ b/drivers/misc/cxl/api.c
> @@ -14,6 +14,7 @@
>  #include <linux/msi.h>
>  #include <linux/module.h>
>  #include <linux/mount.h>
> +#include <linux/sched/mm.h>
>
>  #include "cxl.h"
>
> @@ -321,19 +322,29 @@ int cxl_start_context(struct cxl_context *ctx, u64 wed,
>
>  	if (task) {
>  		ctx->pid = get_task_pid(task, PIDTYPE_PID);
> -		ctx->glpid = get_task_pid(task->group_leader, PIDTYPE_PID);
>  		kernel = false;
>  		ctx->real_mode = false;
> +
> +		/* acquire a reference to the task's mm */
> +		ctx->mm = get_task_mm(current);
> +
> +		/* ensure this mm_struct can't be freed */
> +		cxl_context_mm_count_get(ctx);
> +
> +		/* decrement the use count */
> +		if (ctx->mm)
> +			mmput(ctx->mm);
>  	}
>
>  	cxl_ctx_get();
>
>  	if ((rc = cxl_ops->attach_process(ctx, kernel, wed, 0))) {
> -		put_pid(ctx->glpid);
>  		put_pid(ctx->pid);
> -		ctx->glpid = ctx->pid = NULL;
> +		ctx->pid = NULL;
>  		cxl_adapter_context_put(ctx->afu->adapter);
>  		cxl_ctx_put();
> +		if (task)
> +			cxl_context_mm_count_put(ctx);
>  		goto out;
>  	}
>
> diff --git a/drivers/misc/cxl/context.c b/drivers/misc/cxl/context.c
> index 062bf6c..ed0a447 100644
> --- a/drivers/misc/cxl/context.c
> +++ b/drivers/misc/cxl/context.c
> @@ -17,6 +17,7 @@
>  #include <linux/debugfs.h>
>  #include <linux/slab.h>
>  #include <linux/idr.h>
> +#include <linux/sched/mm.h>
>  #include <asm/cputable.h>
>  #include <asm/current.h>
>  #include <asm/copro.h>
> @@ -41,7 +42,7 @@ int cxl_context_init(struct cxl_context *ctx, struct cxl_afu *afu, bool master)
>  	spin_lock_init(&ctx->sste_lock);
>  	ctx->afu = afu;
>  	ctx->master = master;
> -	ctx->pid = ctx->glpid = NULL; /* Set in start work ioctl */
> +	ctx->pid = NULL; /* Set in start work ioctl */
>  	mutex_init(&ctx->mapping_lock);
>  	ctx->mapping = NULL;
>
> @@ -242,12 +243,15 @@ int __detach_context(struct cxl_context *ctx)
>
>  	/* release the reference to the group leader and mm handling pid */
>  	put_pid(ctx->pid);
> -	put_pid(ctx->glpid);
>
>  	cxl_ctx_put();
>
>  	/* Decrease the attached context count on the adapter */
>  	cxl_adapter_context_put(ctx->afu->adapter);
> +
> +	/* Decrease the mm count on the context */
> +	cxl_context_mm_count_put(ctx);
> +
>  	return 0;
>  }
>
> @@ -325,3 +329,21 @@ void cxl_context_free(struct cxl_context *ctx)
>  	mutex_unlock(&ctx->afu->contexts_lock);
>  	call_rcu(&ctx->rcu, reclaim_ctx);
>  }
> +
> +void cxl_context_mm_count_get(struct cxl_context *ctx)
> +{
> +	if (ctx->mm)
> +		atomic_inc(&ctx->mm->mm_count);
> +}
> +
> +void cxl_context_mm_count_put(struct cxl_context *ctx)
> +{
> +	if (ctx->mm)
> +		mmdrop(ctx->mm);
> +}
> +
> +void cxl_context_mm_users_get(struct cxl_context *ctx)
> +{
> +	if (ctx->mm)
> +		atomic_inc(&ctx->mm->mm_users);
> +}
> diff --git a/drivers/misc/cxl/cxl.h b/drivers/misc/cxl/cxl.h
> index 79e60ec..4d1b704 100644
> --- a/drivers/misc/cxl/cxl.h
> +++ b/drivers/misc/cxl/cxl.h
> @@ -482,8 +482,6 @@ struct cxl_context {
>  	unsigned int sst_size, sst_lru;
>
>  	wait_queue_head_t wq;
> -	/* pid of the group leader associated with the pid */
> -	struct pid *glpid;
>  	/* use mm context associated with this pid for ds faults */
>  	struct pid *pid;
>  	spinlock_t lock; /* Protects pending_irq_mask, pending_fault and fault_addr */
> @@ -551,6 +549,8 @@ struct cxl_context {
>  	 * CX4 only:
>  	 */
>  	struct list_head extra_irq_contexts;
> +
> +	struct mm_struct *mm;
>  };
>
>  struct cxl_service_layer_ops {
> @@ -1024,4 +1024,13 @@ int cxl_adapter_context_lock(struct cxl *adapter);
>  /* Unlock the contexts-lock if taken. Warn and force unlock otherwise */
>  void cxl_adapter_context_unlock(struct cxl *adapter);
>
> +/* Increases the reference count to "struct mm_struct" */
> +void cxl_context_mm_count_get(struct cxl_context *ctx);
> +
> +/* Decrements the reference count to "struct mm_struct" */
> +void cxl_context_mm_count_put(struct cxl_context *ctx);
> +
> +/* Increases the reference users to "struct mm_struct" */
> +void cxl_context_mm_users_get(struct cxl_context *ctx);
> +
>  #endif
> diff --git a/drivers/misc/cxl/fault.c b/drivers/misc/cxl/fault.c
> index 2fa015c..14a5bfa 100644
> --- a/drivers/misc/cxl/fault.c
> +++ b/drivers/misc/cxl/fault.c
> @@ -170,81 +170,19 @@ static void cxl_handle_page_fault(struct cxl_context *ctx,
>  }
>
>  /*
> - * Returns the mm_struct corresponding to the context ctx via ctx->pid
> - * In case the task has exited we use the task group leader accessible
> - * via ctx->glpid to find the next task in the thread group that has a
> - * valid  mm_struct associated with it. If a task with valid mm_struct
> - * is found the ctx->pid is updated to use the task struct for subsequent
> - * translations. In case no valid mm_struct is found in the task group to
> - * service the fault a NULL is returned.
> + * Returns the mm_struct corresponding to the context ctx.
> + * mm_users == 0, the context may be in the process of being closed.
>   */
>  static struct mm_struct *get_mem_context(struct cxl_context *ctx)
>  {
> -	struct task_struct *task = NULL;
> -	struct mm_struct *mm = NULL;
> -	struct pid *old_pid = ctx->pid;
> -
> -	if (old_pid == NULL) {
> -		pr_warn("%s: Invalid context for pe=%d\n",
> -			 __func__, ctx->pe);
> +	if (ctx->mm == NULL)
>  		return NULL;
> -	}
> -
> -	task = get_pid_task(old_pid, PIDTYPE_PID);
> -
> -	/*
> -	 * pid_alive may look racy but this saves us from costly
> -	 * get_task_mm when the task is a zombie. In worst case
> -	 * we may think a task is alive, which is about to die
> -	 * but get_task_mm will return NULL.
> -	 */
> -	if (task != NULL && pid_alive(task))
> -		mm = get_task_mm(task);
>
> -	/* release the task struct that was taken earlier */
> -	if (task)
> -		put_task_struct(task);
> -	else
> -		pr_devel("%s: Context owning pid=%i for pe=%i dead\n",
> -			__func__, pid_nr(old_pid), ctx->pe);
> -
> -	/*
> -	 * If we couldn't find the mm context then use the group
> -	 * leader to iterate over the task group and find a task
> -	 * that gives us mm_struct.
> -	 */
> -	if (unlikely(mm == NULL && ctx->glpid != NULL)) {
> -
> -		rcu_read_lock();
> -		task = pid_task(ctx->glpid, PIDTYPE_PID);
> -		if (task)
> -			do {
> -				mm = get_task_mm(task);
> -				if (mm) {
> -					ctx->pid = get_task_pid(task,
> -								PIDTYPE_PID);
> -					break;
> -				}
> -				task = next_thread(task);
> -			} while (task && !thread_group_leader(task));
> -		rcu_read_unlock();
> -
> -		/* check if we switched pid */
> -		if (ctx->pid != old_pid) {
> -			if (mm)
> -				pr_devel("%s:pe=%i switch pid %i->%i\n",
> -					 __func__, ctx->pe, pid_nr(old_pid),
> -					 pid_nr(ctx->pid));
> -			else
> -				pr_devel("%s:Cannot find mm for pid=%i\n",
> -					 __func__, pid_nr(old_pid));
> -
> -			/* drop the reference to older pid */
> -			put_pid(old_pid);
> -		}
> -	}
> +	if (atomic_read(&ctx->mm->mm_users) == 0)
> +		return NULL;
>
> -	return mm;
> +	cxl_context_mm_users_get(ctx);
> +	return ctx->mm;
>  }
>
>
> @@ -282,7 +220,6 @@ void cxl_handle_fault(struct work_struct *fault_work)
>  	if (!ctx->kernel) {
>
>  		mm = get_mem_context(ctx);
> -		/* indicates all the thread in task group have exited */
>  		if (mm == NULL) {
>  			pr_devel("%s: unable to get mm for pe=%d pid=%i\n",
>  				 __func__, ctx->pe, pid_nr(ctx->pid));
> diff --git a/drivers/misc/cxl/file.c b/drivers/misc/cxl/file.c
> index e7139c7..17b433f 100644
> --- a/drivers/misc/cxl/file.c
> +++ b/drivers/misc/cxl/file.c
> @@ -18,6 +18,7 @@
>  #include <linux/fs.h>
>  #include <linux/mm.h>
>  #include <linux/slab.h>
> +#include <linux/sched/mm.h>
>  #include <asm/cputable.h>
>  #include <asm/current.h>
>  #include <asm/copro.h>
> @@ -216,8 +217,16 @@ static long afu_ioctl_start_work(struct cxl_context *ctx,
>  	 * process is still accessible.
>  	 */
>  	ctx->pid = get_task_pid(current, PIDTYPE_PID);
> -	ctx->glpid = get_task_pid(current->group_leader, PIDTYPE_PID);
>
> +	/* acquire a reference to the task's mm */
> +	ctx->mm = get_task_mm(current);
> +
> +	/* ensure this mm_struct can't be freed */
> +	cxl_context_mm_count_get(ctx);
> +
> +	/* decrement the use count */
> +	if (ctx->mm)
> +		mmput(ctx->mm);
>
>  	trace_cxl_attach(ctx, work.work_element_descriptor, work.num_interrupts, amr);
>
> @@ -225,9 +234,9 @@ static long afu_ioctl_start_work(struct cxl_context *ctx,
>  							amr))) {
>  		afu_release_irqs(ctx, ctx);
>  		cxl_adapter_context_put(ctx->afu->adapter);
> -		put_pid(ctx->glpid);
>  		put_pid(ctx->pid);
> -		ctx->glpid = ctx->pid = NULL;
> +		ctx->pid = NULL;
> +		cxl_context_mm_count_put(ctx);
>  		goto out;
>  	}
>
diff mbox

Patch

diff --git a/drivers/misc/cxl/api.c b/drivers/misc/cxl/api.c
index bcc030e..1a138c8 100644
--- a/drivers/misc/cxl/api.c
+++ b/drivers/misc/cxl/api.c
@@ -14,6 +14,7 @@ 
 #include <linux/msi.h>
 #include <linux/module.h>
 #include <linux/mount.h>
+#include <linux/sched/mm.h>
 
 #include "cxl.h"
 
@@ -321,19 +322,29 @@  int cxl_start_context(struct cxl_context *ctx, u64 wed,
 
 	if (task) {
 		ctx->pid = get_task_pid(task, PIDTYPE_PID);
-		ctx->glpid = get_task_pid(task->group_leader, PIDTYPE_PID);
 		kernel = false;
 		ctx->real_mode = false;
+
+		/* acquire a reference to the task's mm */
+		ctx->mm = get_task_mm(current);
+
+		/* ensure this mm_struct can't be freed */
+		cxl_context_mm_count_get(ctx);
+
+		/* decrement the use count */
+		if (ctx->mm)
+			mmput(ctx->mm);
 	}
 
 	cxl_ctx_get();
 
 	if ((rc = cxl_ops->attach_process(ctx, kernel, wed, 0))) {
-		put_pid(ctx->glpid);
 		put_pid(ctx->pid);
-		ctx->glpid = ctx->pid = NULL;
+		ctx->pid = NULL;
 		cxl_adapter_context_put(ctx->afu->adapter);
 		cxl_ctx_put();
+		if (task)
+			cxl_context_mm_count_put(ctx);
 		goto out;
 	}
 
diff --git a/drivers/misc/cxl/context.c b/drivers/misc/cxl/context.c
index 062bf6c..ed0a447 100644
--- a/drivers/misc/cxl/context.c
+++ b/drivers/misc/cxl/context.c
@@ -17,6 +17,7 @@ 
 #include <linux/debugfs.h>
 #include <linux/slab.h>
 #include <linux/idr.h>
+#include <linux/sched/mm.h>
 #include <asm/cputable.h>
 #include <asm/current.h>
 #include <asm/copro.h>
@@ -41,7 +42,7 @@  int cxl_context_init(struct cxl_context *ctx, struct cxl_afu *afu, bool master)
 	spin_lock_init(&ctx->sste_lock);
 	ctx->afu = afu;
 	ctx->master = master;
-	ctx->pid = ctx->glpid = NULL; /* Set in start work ioctl */
+	ctx->pid = NULL; /* Set in start work ioctl */
 	mutex_init(&ctx->mapping_lock);
 	ctx->mapping = NULL;
 
@@ -242,12 +243,15 @@  int __detach_context(struct cxl_context *ctx)
 
 	/* release the reference to the group leader and mm handling pid */
 	put_pid(ctx->pid);
-	put_pid(ctx->glpid);
 
 	cxl_ctx_put();
 
 	/* Decrease the attached context count on the adapter */
 	cxl_adapter_context_put(ctx->afu->adapter);
+
+	/* Decrease the mm count on the context */
+	cxl_context_mm_count_put(ctx);
+
 	return 0;
 }
 
@@ -325,3 +329,21 @@  void cxl_context_free(struct cxl_context *ctx)
 	mutex_unlock(&ctx->afu->contexts_lock);
 	call_rcu(&ctx->rcu, reclaim_ctx);
 }
+
+void cxl_context_mm_count_get(struct cxl_context *ctx)
+{
+	if (ctx->mm)
+		atomic_inc(&ctx->mm->mm_count);
+}
+
+void cxl_context_mm_count_put(struct cxl_context *ctx)
+{
+	if (ctx->mm)
+		mmdrop(ctx->mm);
+}
+
+void cxl_context_mm_users_get(struct cxl_context *ctx)
+{
+	if (ctx->mm)
+		atomic_inc(&ctx->mm->mm_users);
+}
diff --git a/drivers/misc/cxl/cxl.h b/drivers/misc/cxl/cxl.h
index 79e60ec..4d1b704 100644
--- a/drivers/misc/cxl/cxl.h
+++ b/drivers/misc/cxl/cxl.h
@@ -482,8 +482,6 @@  struct cxl_context {
 	unsigned int sst_size, sst_lru;
 
 	wait_queue_head_t wq;
-	/* pid of the group leader associated with the pid */
-	struct pid *glpid;
 	/* use mm context associated with this pid for ds faults */
 	struct pid *pid;
 	spinlock_t lock; /* Protects pending_irq_mask, pending_fault and fault_addr */
@@ -551,6 +549,8 @@  struct cxl_context {
 	 * CX4 only:
 	 */
 	struct list_head extra_irq_contexts;
+
+	struct mm_struct *mm;
 };
 
 struct cxl_service_layer_ops {
@@ -1024,4 +1024,13 @@  int cxl_adapter_context_lock(struct cxl *adapter);
 /* Unlock the contexts-lock if taken. Warn and force unlock otherwise */
 void cxl_adapter_context_unlock(struct cxl *adapter);
 
+/* Increases the reference count to "struct mm_struct" */
+void cxl_context_mm_count_get(struct cxl_context *ctx);
+
+/* Decrements the reference count to "struct mm_struct" */
+void cxl_context_mm_count_put(struct cxl_context *ctx);
+
+/* Increases the reference users to "struct mm_struct" */
+void cxl_context_mm_users_get(struct cxl_context *ctx);
+
 #endif
diff --git a/drivers/misc/cxl/fault.c b/drivers/misc/cxl/fault.c
index 2fa015c..14a5bfa 100644
--- a/drivers/misc/cxl/fault.c
+++ b/drivers/misc/cxl/fault.c
@@ -170,81 +170,19 @@  static void cxl_handle_page_fault(struct cxl_context *ctx,
 }
 
 /*
- * Returns the mm_struct corresponding to the context ctx via ctx->pid
- * In case the task has exited we use the task group leader accessible
- * via ctx->glpid to find the next task in the thread group that has a
- * valid  mm_struct associated with it. If a task with valid mm_struct
- * is found the ctx->pid is updated to use the task struct for subsequent
- * translations. In case no valid mm_struct is found in the task group to
- * service the fault a NULL is returned.
+ * Returns the mm_struct corresponding to the context ctx.
+ * mm_users == 0, the context may be in the process of being closed.
  */
 static struct mm_struct *get_mem_context(struct cxl_context *ctx)
 {
-	struct task_struct *task = NULL;
-	struct mm_struct *mm = NULL;
-	struct pid *old_pid = ctx->pid;
-
-	if (old_pid == NULL) {
-		pr_warn("%s: Invalid context for pe=%d\n",
-			 __func__, ctx->pe);
+	if (ctx->mm == NULL)
 		return NULL;
-	}
-
-	task = get_pid_task(old_pid, PIDTYPE_PID);
-
-	/*
-	 * pid_alive may look racy but this saves us from costly
-	 * get_task_mm when the task is a zombie. In worst case
-	 * we may think a task is alive, which is about to die
-	 * but get_task_mm will return NULL.
-	 */
-	if (task != NULL && pid_alive(task))
-		mm = get_task_mm(task);
 
-	/* release the task struct that was taken earlier */
-	if (task)
-		put_task_struct(task);
-	else
-		pr_devel("%s: Context owning pid=%i for pe=%i dead\n",
-			__func__, pid_nr(old_pid), ctx->pe);
-
-	/*
-	 * If we couldn't find the mm context then use the group
-	 * leader to iterate over the task group and find a task
-	 * that gives us mm_struct.
-	 */
-	if (unlikely(mm == NULL && ctx->glpid != NULL)) {
-
-		rcu_read_lock();
-		task = pid_task(ctx->glpid, PIDTYPE_PID);
-		if (task)
-			do {
-				mm = get_task_mm(task);
-				if (mm) {
-					ctx->pid = get_task_pid(task,
-								PIDTYPE_PID);
-					break;
-				}
-				task = next_thread(task);
-			} while (task && !thread_group_leader(task));
-		rcu_read_unlock();
-
-		/* check if we switched pid */
-		if (ctx->pid != old_pid) {
-			if (mm)
-				pr_devel("%s:pe=%i switch pid %i->%i\n",
-					 __func__, ctx->pe, pid_nr(old_pid),
-					 pid_nr(ctx->pid));
-			else
-				pr_devel("%s:Cannot find mm for pid=%i\n",
-					 __func__, pid_nr(old_pid));
-
-			/* drop the reference to older pid */
-			put_pid(old_pid);
-		}
-	}
+	if (atomic_read(&ctx->mm->mm_users) == 0)
+		return NULL;
 
-	return mm;
+	cxl_context_mm_users_get(ctx);
+	return ctx->mm;
 }
 
 
@@ -282,7 +220,6 @@  void cxl_handle_fault(struct work_struct *fault_work)
 	if (!ctx->kernel) {
 
 		mm = get_mem_context(ctx);
-		/* indicates all the thread in task group have exited */
 		if (mm == NULL) {
 			pr_devel("%s: unable to get mm for pe=%d pid=%i\n",
 				 __func__, ctx->pe, pid_nr(ctx->pid));
diff --git a/drivers/misc/cxl/file.c b/drivers/misc/cxl/file.c
index e7139c7..17b433f 100644
--- a/drivers/misc/cxl/file.c
+++ b/drivers/misc/cxl/file.c
@@ -18,6 +18,7 @@ 
 #include <linux/fs.h>
 #include <linux/mm.h>
 #include <linux/slab.h>
+#include <linux/sched/mm.h>
 #include <asm/cputable.h>
 #include <asm/current.h>
 #include <asm/copro.h>
@@ -216,8 +217,16 @@  static long afu_ioctl_start_work(struct cxl_context *ctx,
 	 * process is still accessible.
 	 */
 	ctx->pid = get_task_pid(current, PIDTYPE_PID);
-	ctx->glpid = get_task_pid(current->group_leader, PIDTYPE_PID);
 
+	/* acquire a reference to the task's mm */
+	ctx->mm = get_task_mm(current);
+
+	/* ensure this mm_struct can't be freed */
+	cxl_context_mm_count_get(ctx);
+
+	/* decrement the use count */
+	if (ctx->mm)
+		mmput(ctx->mm);
 
 	trace_cxl_attach(ctx, work.work_element_descriptor, work.num_interrupts, amr);
 
@@ -225,9 +234,9 @@  static long afu_ioctl_start_work(struct cxl_context *ctx,
 							amr))) {
 		afu_release_irqs(ctx, ctx);
 		cxl_adapter_context_put(ctx->afu->adapter);
-		put_pid(ctx->glpid);
 		put_pid(ctx->pid);
-		ctx->glpid = ctx->pid = NULL;
+		ctx->pid = NULL;
+		cxl_context_mm_count_put(ctx);
 		goto out;
 	}