diff mbox series

[v6,22/24] mm: Speculative page fault handler return VMA

Message ID 1515777968-867-23-git-send-email-ldufour@linux.vnet.ibm.com (mailing list archive)
State Not Applicable
Headers show
Series Speculative page faults | expand

Commit Message

Laurent Dufour Jan. 12, 2018, 5:26 p.m. UTC
When the speculative page fault handler returns VM_FAULT_RETRY, there is a
chance that the VMA fetched without grabbing the mmap_sem can be reused by the
legacy page fault handler.  By reusing it, we avoid calling find_vma()
again. To achieve that, we must ensure that the VMA structure will not be
freed behind our back. This is done by taking a reference on it (get_vma())
and by assuming that the caller will call the new service
can_reuse_spf_vma() once it has grabbed the mmap_sem.

can_reuse_spf_vma() first checks that the VMA is still in the RB tree,
then that the VMA's boundaries match the passed address, and finally releases
the reference on the VMA so that it can be freed if needed.

If the VMA ends up being freed by that release, can_reuse_spf_vma() will have
returned false, as the VMA was no longer in the RB tree.

Signed-off-by: Laurent Dufour <ldufour@linux.vnet.ibm.com>
---
 include/linux/mm.h |   5 +-
 mm/memory.c        | 136 +++++++++++++++++++++++++++++++++--------------------
 2 files changed, 88 insertions(+), 53 deletions(-)

Comments

Matthew Wilcox (Oracle) Jan. 12, 2018, 7:02 p.m. UTC | #1
On Fri, Jan 12, 2018 at 06:26:06PM +0100, Laurent Dufour wrote:
> @@ -1354,7 +1354,10 @@ extern int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
>  		unsigned int flags);
>  #ifdef CONFIG_SPF
>  extern int handle_speculative_fault(struct mm_struct *mm,
> +				    unsigned long address, unsigned int flags,
> +				    struct vm_area_struct **vma);

I think this shows that we need to create 'struct vm_fault' on the stack
in the arch code and then pass it to handle_speculative_fault(), followed
by handle_mm_fault().  That should be quite a nice cleanup actually.
I know that's only 30+ architectures to change ;-)
Matthew Wilcox (Oracle) Jan. 13, 2018, 4:23 a.m. UTC | #2
On Fri, Jan 12, 2018 at 11:02:51AM -0800, Matthew Wilcox wrote:
> On Fri, Jan 12, 2018 at 06:26:06PM +0100, Laurent Dufour wrote:
> > @@ -1354,7 +1354,10 @@ extern int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
> >  		unsigned int flags);
> >  #ifdef CONFIG_SPF
> >  extern int handle_speculative_fault(struct mm_struct *mm,
> > +				    unsigned long address, unsigned int flags,
> > +				    struct vm_area_struct **vma);
> 
> I think this shows that we need to create 'struct vm_fault' on the stack
> in the arch code and then pass it to handle_speculative_fault(), followed
> by handle_mm_fault().  That should be quite a nice cleanup actually.
> I know that's only 30+ architectures to change ;-)

Of course, we don't need to change them all.  Try this:

Subject: [PATCH] Add vm_handle_fault

For the speculative fault handler, we want to create the struct vm_fault
on the stack in the arch code and pass it into the generic mm code.
To avoid changing 30+ architectures, leave handle_mm_fault with its
current function signature and move its guts into the new vm_handle_fault
function.  Even this saves a nice 172 bytes on the random x86-64 .config
I happen to have around.

Signed-off-by: Matthew Wilcox <mawilcox@microsoft.com>

diff --git a/mm/memory.c b/mm/memory.c
index 5eb3d2524bdc..403934297a3d 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -3977,36 +3977,28 @@ static int handle_pte_fault(struct vm_fault *vmf)
  * The mmap_sem may have been released depending on flags and our
  * return value.  See filemap_fault() and __lock_page_or_retry().
  */
-static int __handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
-		unsigned int flags)
+static int __handle_mm_fault(struct vm_fault *vmf)
 {
-	struct vm_fault vmf = {
-		.vma = vma,
-		.address = address & PAGE_MASK,
-		.flags = flags,
-		.pgoff = linear_page_index(vma, address),
-		.gfp_mask = __get_fault_gfp_mask(vma),
-	};
-	unsigned int dirty = flags & FAULT_FLAG_WRITE;
-	struct mm_struct *mm = vma->vm_mm;
+	unsigned int dirty = vmf->flags & FAULT_FLAG_WRITE;
+	struct mm_struct *mm = vmf->vma->vm_mm;
 	pgd_t *pgd;
 	p4d_t *p4d;
 	int ret;
 
-	pgd = pgd_offset(mm, address);
-	p4d = p4d_alloc(mm, pgd, address);
+	pgd = pgd_offset(mm, vmf->address);
+	p4d = p4d_alloc(mm, pgd, vmf->address);
 	if (!p4d)
 		return VM_FAULT_OOM;
 
-	vmf.pud = pud_alloc(mm, p4d, address);
-	if (!vmf.pud)
+	vmf->pud = pud_alloc(mm, p4d, vmf->address);
+	if (!vmf->pud)
 		return VM_FAULT_OOM;
-	if (pud_none(*vmf.pud) && transparent_hugepage_enabled(vma)) {
-		ret = create_huge_pud(&vmf);
+	if (pud_none(*vmf->pud) && transparent_hugepage_enabled(vmf->vma)) {
+		ret = create_huge_pud(vmf);
 		if (!(ret & VM_FAULT_FALLBACK))
 			return ret;
 	} else {
-		pud_t orig_pud = *vmf.pud;
+		pud_t orig_pud = *vmf->pud;
 
 		barrier();
 		if (pud_trans_huge(orig_pud) || pud_devmap(orig_pud)) {
@@ -4014,50 +4006,51 @@ static int __handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 			/* NUMA case for anonymous PUDs would go here */
 
 			if (dirty && !pud_access_permitted(orig_pud, WRITE)) {
-				ret = wp_huge_pud(&vmf, orig_pud);
+				ret = wp_huge_pud(vmf, orig_pud);
 				if (!(ret & VM_FAULT_FALLBACK))
 					return ret;
 			} else {
-				huge_pud_set_accessed(&vmf, orig_pud);
+				huge_pud_set_accessed(vmf, orig_pud);
 				return 0;
 			}
 		}
 	}
 
-	vmf.pmd = pmd_alloc(mm, vmf.pud, address);
-	if (!vmf.pmd)
+	vmf->pmd = pmd_alloc(mm, vmf->pud, vmf->address);
+	if (!vmf->pmd)
 		return VM_FAULT_OOM;
-	if (pmd_none(*vmf.pmd) && transparent_hugepage_enabled(vma)) {
-		ret = create_huge_pmd(&vmf);
+	if (pmd_none(*vmf->pmd) && transparent_hugepage_enabled(vmf->vma)) {
+		ret = create_huge_pmd(vmf);
 		if (!(ret & VM_FAULT_FALLBACK))
 			return ret;
 	} else {
-		pmd_t orig_pmd = *vmf.pmd;
+		pmd_t orig_pmd = *vmf->pmd;
 
 		barrier();
 		if (unlikely(is_swap_pmd(orig_pmd))) {
 			VM_BUG_ON(thp_migration_supported() &&
 					  !is_pmd_migration_entry(orig_pmd));
 			if (is_pmd_migration_entry(orig_pmd))
-				pmd_migration_entry_wait(mm, vmf.pmd);
+				pmd_migration_entry_wait(mm, vmf->pmd);
 			return 0;
 		}
 		if (pmd_trans_huge(orig_pmd) || pmd_devmap(orig_pmd)) {
-			if (pmd_protnone(orig_pmd) && vma_is_accessible(vma))
-				return do_huge_pmd_numa_page(&vmf, orig_pmd);
+			if (pmd_protnone(orig_pmd) &&
+						vma_is_accessible(vmf->vma))
+				return do_huge_pmd_numa_page(vmf, orig_pmd);
 
 			if (dirty && !pmd_access_permitted(orig_pmd, WRITE)) {
-				ret = wp_huge_pmd(&vmf, orig_pmd);
+				ret = wp_huge_pmd(vmf, orig_pmd);
 				if (!(ret & VM_FAULT_FALLBACK))
 					return ret;
 			} else {
-				huge_pmd_set_accessed(&vmf, orig_pmd);
+				huge_pmd_set_accessed(vmf, orig_pmd);
 				return 0;
 			}
 		}
 	}
 
-	return handle_pte_fault(&vmf);
+	return handle_pte_fault(vmf);
 }
 
 /*
@@ -4066,9 +4059,10 @@ static int __handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
  * The mmap_sem may have been released depending on flags and our
  * return value.  See filemap_fault() and __lock_page_or_retry().
  */
-int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
-		unsigned int flags)
+int vm_handle_fault(struct vm_fault *vmf)
 {
+	unsigned int flags = vmf->flags;
+	struct vm_area_struct *vma = vmf->vma;
 	int ret;
 
 	__set_current_state(TASK_RUNNING);
@@ -4092,9 +4086,9 @@ int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 		mem_cgroup_oom_enable();
 
 	if (unlikely(is_vm_hugetlb_page(vma)))
-		ret = hugetlb_fault(vma->vm_mm, vma, address, flags);
+		ret = hugetlb_fault(vma->vm_mm, vma, vmf->address, flags);
 	else
-		ret = __handle_mm_fault(vma, address, flags);
+		ret = __handle_mm_fault(vmf);
 
 	if (flags & FAULT_FLAG_USER) {
 		mem_cgroup_oom_disable();
@@ -4110,6 +4104,26 @@ int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 
 	return ret;
 }
+
+/*
+ * By the time we get here, we already hold the mm semaphore
+ *
+ * The mmap_sem may have been released depending on flags and our
+ * return value.  See filemap_fault() and __lock_page_or_retry().
+ */
+int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
+		unsigned int flags)
+{
+	struct vm_fault vmf = {
+		.vma = vma,
+		.address = address & PAGE_MASK,
+		.flags = flags,
+		.pgoff = linear_page_index(vma, address),
+		.gfp_mask = __get_fault_gfp_mask(vma),
+	};
+
+	return vm_handle_fault(&vmf);
+}
 EXPORT_SYMBOL_GPL(handle_mm_fault);
 
 #ifndef __PAGETABLE_P4D_FOLDED
Laurent Dufour Jan. 16, 2018, 2:47 p.m. UTC | #3
On 13/01/2018 05:23, Matthew Wilcox wrote:
> On Fri, Jan 12, 2018 at 11:02:51AM -0800, Matthew Wilcox wrote:
>> On Fri, Jan 12, 2018 at 06:26:06PM +0100, Laurent Dufour wrote:
>>> @@ -1354,7 +1354,10 @@ extern int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
>>>  		unsigned int flags);
>>>  #ifdef CONFIG_SPF
>>>  extern int handle_speculative_fault(struct mm_struct *mm,
>>> +				    unsigned long address, unsigned int flags,
>>> +				    struct vm_area_struct **vma);
>>
>> I think this shows that we need to create 'struct vm_fault' on the stack
>> in the arch code and then pass it to handle_speculative_fault(), followed
>> by handle_mm_fault().  That should be quite a nice cleanup actually.
>> I know that's only 30+ architectures to change ;-)
> 
> Of course, we don't need to change them all.  Try this:

That would be a good candidate for a cleanup, but I'm not sure it should be
part of this already too long series.

If you don't mind, unless a general agreement is reached on that, I'd prefer
to postpone such a change until the initial series is accepted.

Cheers,
Laurent.

> Subject: [PATCH] Add vm_handle_fault
> 
> For the speculative fault handler, we want to create the struct vm_fault
> on the stack in the arch code and pass it into the generic mm code.
> To avoid changing 30+ architectures, leave handle_mm_fault with its
> current function signature and move its guts into the new vm_handle_fault
> function.  Even this saves a nice 172 bytes on the random x86-64 .config
> I happen to have around.
> 
> Signed-off-by: Matthew Wilcox <mawilcox@microsoft.com>
> 
> diff --git a/mm/memory.c b/mm/memory.c
> index 5eb3d2524bdc..403934297a3d 100644
> --- a/mm/memory.c
> +++ b/mm/memory.c
> @@ -3977,36 +3977,28 @@ static int handle_pte_fault(struct vm_fault *vmf)
>   * The mmap_sem may have been released depending on flags and our
>   * return value.  See filemap_fault() and __lock_page_or_retry().
>   */
> -static int __handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
> -		unsigned int flags)
> +static int __handle_mm_fault(struct vm_fault *vmf)
>  {
> -	struct vm_fault vmf = {
> -		.vma = vma,
> -		.address = address & PAGE_MASK,
> -		.flags = flags,
> -		.pgoff = linear_page_index(vma, address),
> -		.gfp_mask = __get_fault_gfp_mask(vma),
> -	};
> -	unsigned int dirty = flags & FAULT_FLAG_WRITE;
> -	struct mm_struct *mm = vma->vm_mm;
> +	unsigned int dirty = vmf->flags & FAULT_FLAG_WRITE;
> +	struct mm_struct *mm = vmf->vma->vm_mm;
>  	pgd_t *pgd;
>  	p4d_t *p4d;
>  	int ret;
> 
> -	pgd = pgd_offset(mm, address);
> -	p4d = p4d_alloc(mm, pgd, address);
> +	pgd = pgd_offset(mm, vmf->address);
> +	p4d = p4d_alloc(mm, pgd, vmf->address);
>  	if (!p4d)
>  		return VM_FAULT_OOM;
> 
> -	vmf.pud = pud_alloc(mm, p4d, address);
> -	if (!vmf.pud)
> +	vmf->pud = pud_alloc(mm, p4d, vmf->address);
> +	if (!vmf->pud)
>  		return VM_FAULT_OOM;
> -	if (pud_none(*vmf.pud) && transparent_hugepage_enabled(vma)) {
> -		ret = create_huge_pud(&vmf);
> +	if (pud_none(*vmf->pud) && transparent_hugepage_enabled(vmf->vma)) {
> +		ret = create_huge_pud(vmf);
>  		if (!(ret & VM_FAULT_FALLBACK))
>  			return ret;
>  	} else {
> -		pud_t orig_pud = *vmf.pud;
> +		pud_t orig_pud = *vmf->pud;
> 
>  		barrier();
>  		if (pud_trans_huge(orig_pud) || pud_devmap(orig_pud)) {
> @@ -4014,50 +4006,51 @@ static int __handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
>  			/* NUMA case for anonymous PUDs would go here */
> 
>  			if (dirty && !pud_access_permitted(orig_pud, WRITE)) {
> -				ret = wp_huge_pud(&vmf, orig_pud);
> +				ret = wp_huge_pud(vmf, orig_pud);
>  				if (!(ret & VM_FAULT_FALLBACK))
>  					return ret;
>  			} else {
> -				huge_pud_set_accessed(&vmf, orig_pud);
> +				huge_pud_set_accessed(vmf, orig_pud);
>  				return 0;
>  			}
>  		}
>  	}
> 
> -	vmf.pmd = pmd_alloc(mm, vmf.pud, address);
> -	if (!vmf.pmd)
> +	vmf->pmd = pmd_alloc(mm, vmf->pud, vmf->address);
> +	if (!vmf->pmd)
>  		return VM_FAULT_OOM;
> -	if (pmd_none(*vmf.pmd) && transparent_hugepage_enabled(vma)) {
> -		ret = create_huge_pmd(&vmf);
> +	if (pmd_none(*vmf->pmd) && transparent_hugepage_enabled(vmf->vma)) {
> +		ret = create_huge_pmd(vmf);
>  		if (!(ret & VM_FAULT_FALLBACK))
>  			return ret;
>  	} else {
> -		pmd_t orig_pmd = *vmf.pmd;
> +		pmd_t orig_pmd = *vmf->pmd;
> 
>  		barrier();
>  		if (unlikely(is_swap_pmd(orig_pmd))) {
>  			VM_BUG_ON(thp_migration_supported() &&
>  					  !is_pmd_migration_entry(orig_pmd));
>  			if (is_pmd_migration_entry(orig_pmd))
> -				pmd_migration_entry_wait(mm, vmf.pmd);
> +				pmd_migration_entry_wait(mm, vmf->pmd);
>  			return 0;
>  		}
>  		if (pmd_trans_huge(orig_pmd) || pmd_devmap(orig_pmd)) {
> -			if (pmd_protnone(orig_pmd) && vma_is_accessible(vma))
> -				return do_huge_pmd_numa_page(&vmf, orig_pmd);
> +			if (pmd_protnone(orig_pmd) &&
> +						vma_is_accessible(vmf->vma))
> +				return do_huge_pmd_numa_page(vmf, orig_pmd);
> 
>  			if (dirty && !pmd_access_permitted(orig_pmd, WRITE)) {
> -				ret = wp_huge_pmd(&vmf, orig_pmd);
> +				ret = wp_huge_pmd(vmf, orig_pmd);
>  				if (!(ret & VM_FAULT_FALLBACK))
>  					return ret;
>  			} else {
> -				huge_pmd_set_accessed(&vmf, orig_pmd);
> +				huge_pmd_set_accessed(vmf, orig_pmd);
>  				return 0;
>  			}
>  		}
>  	}
> 
> -	return handle_pte_fault(&vmf);
> +	return handle_pte_fault(vmf);
>  }
> 
>  /*
> @@ -4066,9 +4059,10 @@ static int __handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
>   * The mmap_sem may have been released depending on flags and our
>   * return value.  See filemap_fault() and __lock_page_or_retry().
>   */
> -int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
> -		unsigned int flags)
> +int vm_handle_fault(struct vm_fault *vmf)
>  {
> +	unsigned int flags = vmf->flags;
> +	struct vm_area_struct *vma = vmf->vma;
>  	int ret;
> 
>  	__set_current_state(TASK_RUNNING);
> @@ -4092,9 +4086,9 @@ int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
>  		mem_cgroup_oom_enable();
> 
>  	if (unlikely(is_vm_hugetlb_page(vma)))
> -		ret = hugetlb_fault(vma->vm_mm, vma, address, flags);
> +		ret = hugetlb_fault(vma->vm_mm, vma, vmf->address, flags);
>  	else
> -		ret = __handle_mm_fault(vma, address, flags);
> +		ret = __handle_mm_fault(vmf);
> 
>  	if (flags & FAULT_FLAG_USER) {
>  		mem_cgroup_oom_disable();
> @@ -4110,6 +4104,26 @@ int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
> 
>  	return ret;
>  }
> +
> +/*
> + * By the time we get here, we already hold the mm semaphore
> + *
> + * The mmap_sem may have been released depending on flags and our
> + * return value.  See filemap_fault() and __lock_page_or_retry().
> + */
> +int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
> +		unsigned int flags)
> +{
> +	struct vm_fault vmf = {
> +		.vma = vma,
> +		.address = address & PAGE_MASK,
> +		.flags = flags,
> +		.pgoff = linear_page_index(vma, address),
> +		.gfp_mask = __get_fault_gfp_mask(vma),
> +	};
> +
> +	return vm_handle_fault(&vmf);
> +}
>  EXPORT_SYMBOL_GPL(handle_mm_fault);
> 
>  #ifndef __PAGETABLE_P4D_FOLDED
>
Matthew Wilcox (Oracle) Jan. 16, 2018, 2:58 p.m. UTC | #4
On Tue, Jan 16, 2018 at 03:47:51PM +0100, Laurent Dufour wrote:
> On 13/01/2018 05:23, Matthew Wilcox wrote:
> > Of course, we don't need to change them all.  Try this:
> 
> That would be good candidate for a clean up but I'm not sure this should be
> part of this already too long series.
> 
> If you don't mind, unless a global agreement is stated on that, I'd prefer
> to postpone such a change once the initial series is accepted.

Actually, I think this can go in first, independently of the speculative
fault series.  It's a win in memory savings, and probably shaves a
cycle or two off the fault handler due to less argument marshalling in
the call-stack.
diff mbox series

Patch

diff --git a/include/linux/mm.h b/include/linux/mm.h
index 4d8a7621da8a..02da17792f0d 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1354,7 +1354,10 @@  extern int handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 		unsigned int flags);
 #ifdef CONFIG_SPF
 extern int handle_speculative_fault(struct mm_struct *mm,
-				    unsigned long address, unsigned int flags);
+				    unsigned long address, unsigned int flags,
+				    struct vm_area_struct **vma);
+extern bool can_reuse_spf_vma(struct vm_area_struct *vma,
+			      unsigned long address);
 #endif /* CONFIG_SPF */
 extern int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
 			    unsigned long address, unsigned int fault_flags,
diff --git a/mm/memory.c b/mm/memory.c
index 6ccb1f45473a..e1f172ac2c90 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -4284,13 +4284,22 @@  static int __handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 /* This is required by vm_normal_page() */
 #error "Speculative page fault handler requires __HAVE_ARCH_PTE_SPECIAL"
 #endif
-
 /*
  * vm_normal_page() adds some processing which should be done while
  * hodling the mmap_sem.
  */
+
+/*
+ * Tries to handle the page fault in a speculative way, without grabbing the
+ * mmap_sem.
+ * When VM_FAULT_RETRY is returned, *vma is either NULL (no vma was found) or
+ * a vma whose reference is still held. In the latter case the caller must
+ * call can_reuse_spf_vma() once the mmap_sem has been grabbed, both to check
+ * that the vma is still usable and to release the reference that keeps it
+ * in memory.
+ */
 int handle_speculative_fault(struct mm_struct *mm, unsigned long address,
-			     unsigned int flags)
+			     unsigned int flags, struct vm_area_struct **vma)
 {
 	struct vm_fault vmf = {
 		.address = address,
@@ -4299,7 +4308,6 @@  int handle_speculative_fault(struct mm_struct *mm, unsigned long address,
 	p4d_t *p4d, p4dval;
 	pud_t pudval;
 	int seq, ret = VM_FAULT_RETRY;
-	struct vm_area_struct *vma;
 #ifdef CONFIG_NUMA
 	struct mempolicy *pol;
 #endif
@@ -4308,14 +4316,16 @@  int handle_speculative_fault(struct mm_struct *mm, unsigned long address,
 	flags &= ~(FAULT_FLAG_ALLOW_RETRY|FAULT_FLAG_KILLABLE);
 	flags |= FAULT_FLAG_SPECULATIVE;
 
-	vma = get_vma(mm, address);
-	if (!vma)
+	*vma = get_vma(mm, address);
+	if (!*vma)
 		return ret;
+	vmf.vma = *vma;
 
-	seq = raw_read_seqcount(&vma->vm_sequence); /* rmb <-> seqlock,vma_rb_erase() */
+	/* rmb <-> seqlock,vma_rb_erase() */
+	seq = raw_read_seqcount(&vmf.vma->vm_sequence);
 	if (seq & 1) {
-		trace_spf_vma_changed(_RET_IP_, vma, address);
-		goto out_put;
+		trace_spf_vma_changed(_RET_IP_, vmf.vma, address);
+		return ret;
 	}
 
 	/*
@@ -4323,9 +4333,9 @@  int handle_speculative_fault(struct mm_struct *mm, unsigned long address,
 	 * with the VMA.
 	 * This include huge page from hugetlbfs.
 	 */
-	if (vma->vm_ops) {
-		trace_spf_vma_notsup(_RET_IP_, vma, address);
-		goto out_put;
+	if (vmf.vma->vm_ops) {
+		trace_spf_vma_notsup(_RET_IP_, vmf.vma, address);
+		return ret;
 	}
 
 	/*
@@ -4333,18 +4343,18 @@  int handle_speculative_fault(struct mm_struct *mm, unsigned long address,
 	 * because vm_next and vm_prev must be safe. This can't be guaranteed
 	 * in the speculative path.
 	 */
-	if (unlikely(!vma->anon_vma)) {
-		trace_spf_vma_notsup(_RET_IP_, vma, address);
-		goto out_put;
+	if (unlikely(!vmf.vma->anon_vma)) {
+		trace_spf_vma_notsup(_RET_IP_, vmf.vma, address);
+		return ret;
 	}
 
-	vmf.vma_flags = READ_ONCE(vma->vm_flags);
-	vmf.vma_page_prot = READ_ONCE(vma->vm_page_prot);
+	vmf.vma_flags = READ_ONCE(vmf.vma->vm_flags);
+	vmf.vma_page_prot = READ_ONCE(vmf.vma->vm_page_prot);
 
 	/* Can't call userland page fault handler in the speculative path */
 	if (unlikely(vmf.vma_flags & VM_UFFD_MISSING)) {
-		trace_spf_vma_notsup(_RET_IP_, vma, address);
-		goto out_put;
+		trace_spf_vma_notsup(_RET_IP_, vmf.vma, address);
+		return ret;
 	}
 
 	if (vmf.vma_flags & VM_GROWSDOWN || vmf.vma_flags & VM_GROWSUP) {
@@ -4353,48 +4363,39 @@  int handle_speculative_fault(struct mm_struct *mm, unsigned long address,
 		 * boundaries but we want to trace it as not supported instead
 		 * of changed.
 		 */
-		trace_spf_vma_notsup(_RET_IP_, vma, address);
-		goto out_put;
+		trace_spf_vma_notsup(_RET_IP_, vmf.vma, address);
+		return ret;
 	}
 
-	if (address < READ_ONCE(vma->vm_start)
-	    || READ_ONCE(vma->vm_end) <= address) {
-		trace_spf_vma_changed(_RET_IP_, vma, address);
-		goto out_put;
+	if (address < READ_ONCE(vmf.vma->vm_start)
+	    || READ_ONCE(vmf.vma->vm_end) <= address) {
+		trace_spf_vma_changed(_RET_IP_, vmf.vma, address);
+		return ret;
 	}
 
-	if (!arch_vma_access_permitted(vma, flags & FAULT_FLAG_WRITE,
+	if (!arch_vma_access_permitted(vmf.vma, flags & FAULT_FLAG_WRITE,
 				       flags & FAULT_FLAG_INSTRUCTION,
-				       flags & FAULT_FLAG_REMOTE)) {
-		trace_spf_vma_access(_RET_IP_, vma, address);
-		ret = VM_FAULT_SIGSEGV;
-		goto out_put;
-	}
+				       flags & FAULT_FLAG_REMOTE))
+		goto out_segv;
 
 	/* This is one is required to check that the VMA has write access set */
 	if (flags & FAULT_FLAG_WRITE) {
-		if (unlikely(!(vmf.vma_flags & VM_WRITE))) {
-			trace_spf_vma_access(_RET_IP_, vma, address);
-			ret = VM_FAULT_SIGSEGV;
-			goto out_put;
-		}
-	} else if (unlikely(!(vmf.vma_flags & (VM_READ|VM_EXEC|VM_WRITE)))) {
-		trace_spf_vma_access(_RET_IP_, vma, address);
-		ret = VM_FAULT_SIGSEGV;
-		goto out_put;
-	}
+		if (unlikely(!(vmf.vma_flags & VM_WRITE)))
+			goto out_segv;
+	} else if (unlikely(!(vmf.vma_flags & (VM_READ|VM_EXEC|VM_WRITE))))
+		goto out_segv;
 
 #ifdef CONFIG_NUMA
 	/*
 	 * MPOL_INTERLEAVE implies additional check in mpol_misplaced() which
 	 * are not compatible with the speculative page fault processing.
 	 */
-	pol = __get_vma_policy(vma, address);
+	pol = __get_vma_policy(vmf.vma, address);
 	if (!pol)
 		pol = get_task_policy(current);
 	if (pol && pol->mode == MPOL_INTERLEAVE) {
-		trace_spf_vma_notsup(_RET_IP_, vma, address);
-		goto out_put;
+		trace_spf_vma_notsup(_RET_IP_, vmf.vma, address);
+		return ret;
 	}
 #endif
 
@@ -4456,9 +4457,8 @@  int handle_speculative_fault(struct mm_struct *mm, unsigned long address,
 		vmf.pte = NULL;
 	}
 
-	vmf.vma = vma;
-	vmf.pgoff = linear_page_index(vma, address);
-	vmf.gfp_mask = __get_fault_gfp_mask(vma);
+	vmf.pgoff = linear_page_index(vmf.vma, address);
+	vmf.gfp_mask = __get_fault_gfp_mask(vmf.vma);
 	vmf.sequence = seq;
 	vmf.flags = flags;
 
@@ -4468,16 +4468,22 @@  int handle_speculative_fault(struct mm_struct *mm, unsigned long address,
 	 * We need to re-validate the VMA after checking the bounds, otherwise
 	 * we might have a false positive on the bounds.
 	 */
-	if (read_seqcount_retry(&vma->vm_sequence, seq)) {
-		trace_spf_vma_changed(_RET_IP_, vma, address);
-		goto out_put;
+	if (read_seqcount_retry(&vmf.vma->vm_sequence, seq)) {
+		trace_spf_vma_changed(_RET_IP_, vmf.vma, address);
+		return ret;
 	}
 
 	mem_cgroup_oom_enable();
 	ret = handle_pte_fault(&vmf);
 	mem_cgroup_oom_disable();
 
-	put_vma(vma);
+	/*
+	 * If there is no need to retry, don't return the vma to the caller.
+	 */
+	if (!(ret & VM_FAULT_RETRY)) {
+		put_vma(vmf.vma);
+		*vma = NULL;
+	}
 
 	/*
 	 * The task may have entered a memcg OOM situation but
@@ -4490,9 +4496,35 @@  int handle_speculative_fault(struct mm_struct *mm, unsigned long address,
 	return ret;
 
 out_walk:
-	trace_spf_vma_notsup(_RET_IP_, vma, address);
+	trace_spf_vma_notsup(_RET_IP_, vmf.vma, address);
 	local_irq_enable();
-out_put:
+	return ret;
+
+out_segv:
+	trace_spf_vma_access(_RET_IP_, vmf.vma, address);
+	/*
+	 * We don't return VM_FAULT_RETRY so the caller is not expected to
+	 * retrieve the fetched VMA.
+	 */
+	put_vma(vmf.vma);
+	*vma = NULL;
+	return VM_FAULT_SIGSEGV;
+}
+
+/*
+ * This is used to check whether the vma fetched by the speculative page fault
+ * handler is still valid when trying the regular fault path while holding the
+ * mmap_sem.
+ * The call to put_vma(vma) must be made after checking the vma's fields, as
+ * the vma may be freed by put_vma(). In such a case it is expected that false
+ * is returned.
+ */
+bool can_reuse_spf_vma(struct vm_area_struct *vma, unsigned long address)
+{
+	bool ret;
+
+	ret = !RB_EMPTY_NODE(&vma->vm_rb) &&
+		vma->vm_start <= address && address < vma->vm_end;
 	put_vma(vma);
 	return ret;
 }