
[RFC] dax, ext2, ext4, XFS: fix data corruption race

Message ID 1453503971-5319-1-git-send-email-ross.zwisler@linux.intel.com
State New, archived

Commit Message

Ross Zwisler Jan. 22, 2016, 11:06 p.m. UTC
With the current DAX code the following race exists:

Process 1                	Process 2
---------			---------

__dax_fault() - read file f, index 0
  get_block() -> returns hole
                             	__dax_fault() - write file f, index 0
                                  get_block() -> allocates blocks
                                  dax_insert_mapping()
  dax_load_hole()
  *data corruption*

An analogous race exists between __dax_fault() loading a hole and
__dax_pmd_fault() allocating a PMD DAX page and trying to insert it, and
that race also ends in data corruption.

One solution to this race was proposed by Jan Kara:

  So we need some exclusion that makes sure pgoff->block mapping
  information is uptodate at the moment we insert it into page tables. The
  simplest reasonably fast thing I can see is:

  When handling a read fault, things stay as is and filesystem protects the
  fault with an equivalent of EXT4_I(inode)->i_mmap_sem held for reading.
  When handling a write fault we first grab EXT4_I(inode)->i_mmap_sem for
  reading and try a read fault. If __dax_fault() sees a hole returned from
  get_blocks() during a write fault, it bails out. Filesystem grabs
  EXT4_I(inode)->i_mmap_sem for writing and retries with different
  get_blocks() callback which will allocate blocks. That way we get proper
  exclusion for faults needing to allocate blocks.

This patch adds this logic to DAX, ext2, ext4 and XFS.  The changes for
these four components all appear in the same patch as opposed to being
spread out among multiple patches in a series because we are changing the
argument list to __dax_fault(), __dax_pmd_fault() and __dax_mkwrite().
This means we can't easily change things one component at a time and still
keep the series bisectable.
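
The resulting pattern is the same in every filesystem; here is a condensed
sketch of the ext2 change below (with the sb_start_pagefault() and
file_update_time() bookkeeping omitted):

static int ext2_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
{
        struct ext2_inode_info *ei = EXT2_I(file_inode(vma->vm_file));
        int ret;

        /* Optimistic pass: shared lock, no block allocation allowed. */
        down_read(&ei->dax_sem);
        ret = __dax_fault(vma, vmf, ext2_get_block, NULL, false);
        up_read(&ei->dax_sem);

        /*
         * A write fault found a hole: retry under the exclusive lock,
         * this time letting get_block() allocate blocks.
         */
        if (ret == -EAGAIN) {
                down_write(&ei->dax_sem);
                ret = __dax_fault(vma, vmf, ext2_get_block, NULL, true);
                up_write(&ei->dax_sem);
        }

        return ret;
}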

For ext4 this patch assumes that the journal entry is only needed when we
are actually allocating blocks with get_block().  An in-depth review of
this logic would be welcome.

I also fixed a bug in the ext4 implementation of ext4_dax_mkwrite().
Previously it assumed that the block allocation was already complete, so it
didn't create a journal entry.  For a read that creates a zero page to
cover a hole followed by a write that actually allocates storage, this is
incorrect.  The ext4_dax_mkwrite() -> __dax_mkwrite() -> __dax_fault() path
would call get_blocks() to allocate storage, so I'm pretty sure we need a
journal entry here.

With that fixed, I noticed that for both ext2 and ext4 the paths through
the .fault and .page_mkwrite vmops were exactly the same, so I removed
ext4_dax_mkwrite() and ext2_dax_mkwrite() and just use ext4_dax_fault() and
ext2_dax_fault() directly instead.

I'm still in the process of testing this patch, which is part of the reason
why it is marked as RFC.  I know of at least one deadlock that is easily
hit by doing a read of a hole followed by a write which allocates storage.
If you're using xfstests you can hit this easily with generic/075 with any
of the three filesystems.  I'll continue to track this down, but I wanted to
send out this RFC to sanity check the general approach.

Signed-off-by: Ross Zwisler <ross.zwisler@linux.intel.com>
Suggested-by: Jan Kara <jack@suse.cz>
---
 fs/block_dev.c      | 19 ++++++++++--
 fs/dax.c            | 20 ++++++++++---
 fs/ext2/file.c      | 41 ++++++++++++-------------
 fs/ext4/file.c      | 86 +++++++++++++++++++++++++----------------------------
 fs/xfs/xfs_file.c   | 28 +++++++++++++----
 include/linux/dax.h |  8 +++--
 6 files changed, 121 insertions(+), 81 deletions(-)

Comments

Matthew Wilcox Jan. 23, 2016, 2:01 a.m. UTC | #1
On Fri, Jan 22, 2016 at 04:06:11PM -0700, Ross Zwisler wrote:
> +++ b/fs/block_dev.c
> @@ -1733,13 +1733,28 @@ static const struct address_space_operations def_blk_aops = {
>   */
>  static int blkdev_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
>  {
> -	return __dax_fault(vma, vmf, blkdev_get_block, NULL);
> +	int ret;
> +
> +	ret = __dax_fault(vma, vmf, blkdev_get_block, NULL, false);
> +
> +	if (WARN_ON_ONCE(ret == -EAGAIN))
> +		ret = VM_FAULT_SIGBUS;
> +
> +	return ret;
>  }

Let's not mix up -E returns and VM_FAULT returns.  We already have a
perfectly good VM_FAULT return value -- VM_FAULT_RETRY.
Dave Chinner Jan. 24, 2016, 10:01 p.m. UTC | #2
On Fri, Jan 22, 2016 at 04:06:11PM -0700, Ross Zwisler wrote:
> With the current DAX code the following race exists:
> 
> Process 1                	Process 2
> ---------			---------
> 
> __dax_fault() - read file f, index 0
>   get_block() -> returns hole
>                              	__dax_fault() - write file f, index 0
>                                   get_block() -> allocates blocks
>                                   dax_insert_mapping()
>   dax_load_hole()
>   *data corruption*
> 
> An analogous race exists between __dax_fault() loading a hole and
> __dax_pmd_fault() allocating a PMD DAX page and trying to insert it, and
> that race also ends in data corruption.

Ok, so why doesn't this problem exist for the normal page cache
insertion case with concurrent read vs write faults?  It's because
the write fault first does a read fault and so the write
fault always has a page in the radix tree for the get_block call
that allocates the extents, right?

And DAX has an optimisation in the page fault part where it skips
the read fault part of the write fault?  And so essentially the DAX
write fault is missing the object (page lock of page in the radix
tree) that the non-DAX write fault uses to avoid this problem?

What happens if we get rid of that DAX write fault optimisation that
skips the initial read fault? The write fault will always run on a
mapping that has a hole loaded, right?, so the race between
dax_load_hole() and dax_insert_mapping() goes away, because nothing
will be calling dax_load_hole() once the write fault is allocating
blocks....

> One solution to this race was proposed by Jan Kara:
> 
>   So we need some exclusion that makes sure pgoff->block mapping
>   information is uptodate at the moment we insert it into page tables. The
>   simplest reasonably fast thing I can see is:
> 
>   When handling a read fault, things stay as is and filesystem protects the
>   fault with an equivalent of EXT4_I(inode)->i_mmap_sem held for reading.
>   When handling a write fault we first grab EXT4_I(inode)->i_mmap_sem for
>   reading and try a read fault. If __dax_fault() sees a hole returned from
>   get_blocks() during a write fault, it bails out. Filesystem grabs
>   EXT4_I(inode)->i_mmap_sem for writing and retries with different
>   get_blocks() callback which will allocate blocks. That way we get proper
>   exclusion for faults needing to allocate blocks.
> 
> This patch adds this logic to DAX, ext2, ext4 and XFS.

It's too ugly to live. It hacks around a special DAX optimisation in
the fault code by adding special case locking to the filesystems,
and adds a significant new locking constraint to the page fault
path.

If the write fault first goes through the read fault path and loads
the hole, this race condition simply does not exist. I'd suggest
that we get rid of the DAX optimisation that skips read fault
processing on write fault so that this problem simply goes away.
Yes, it means write faults on holes will be a little slower (which,
quite frankly, I don't care at all about), but it means we don't
need to hack special cases into code that should not have to care
about various different types of page fault races. Correctness
first, speed later.

FWIW, this also means we can get rid of the hacks in the filesystem
code where we have to handle write faults through the ->fault
handler rather than the ->page_mkwrite handler.

Cheers,

Dave.
Jan Kara Jan. 25, 2016, 1:59 p.m. UTC | #3
On Mon 25-01-16 09:01:07, Dave Chinner wrote:
> On Fri, Jan 22, 2016 at 04:06:11PM -0700, Ross Zwisler wrote:
> > With the current DAX code the following race exists:
> > 
> > Process 1                	Process 2
> > ---------			---------
> > 
> > __dax_fault() - read file f, index 0
> >   get_block() -> returns hole
> >                              	__dax_fault() - write file f, index 0
> >                                   get_block() -> allocates blocks
> >                                   dax_insert_mapping()
> >   dax_load_hole()
> >   *data corruption*
> > 
> > An analogous race exists between __dax_fault() loading a hole and
> > __dax_pmd_fault() allocating a PMD DAX page and trying to insert it, and
> > that race also ends in data corruption.
> 
> Ok, so why doesn't this problem exist for the normal page cache
> insertion case with concurrent read vs write faults?  It's because
> the write fault first does a read fault and so the write
> fault always has a page in the radix tree for the get_block call
> that allocates the extents, right?

Yes, any fault (read or write) has a page to lock, which avoids races in
the normal fault path.

> And DAX has an optimisation in the page fault part where it skips
> the read fault part of the write fault?  And so essentially the DAX
> write fault is missing the object (page lock of page in the radix
> tree) that the non-DAX write fault uses to avoid this problem?
> 
> What happens if we get rid of that DAX write fault optimisation that
> skips the initial read fault? The write fault will always run on a
> mapping that has a hole loaded, right?, so the race between
> dax_load_hole() and dax_insert_mapping() goes away, because nothing
> will be calling dax_load_hole() once the write fault is allocating
> blocks....

So frankly I don't like mixing page locks into the DAX fault locking.
Also your scheme would require more tricks to deal with races between PMD
write faults and PTE read faults, since you don't want to require
2MB worth of hole-pages to be able to do a PMD write fault. Transparent
huge pages deal with this issue using compound pages but I'd like to avoid
that horror in the DAX path...

> > One solution to this race was proposed by Jan Kara:
> > 
> >   So we need some exclusion that makes sure pgoff->block mapping
> >   information is uptodate at the moment we insert it into page tables. The
> >   simplest reasonably fast thing I can see is:
> > 
> >   When handling a read fault, things stay as is and filesystem protects the
> >   fault with an equivalent of EXT4_I(inode)->i_mmap_sem held for reading.
> >   When handling a write fault we first grab EXT4_I(inode)->i_mmap_sem for
> >   reading and try a read fault. If __dax_fault() sees a hole returned from
> >   get_blocks() during a write fault, it bails out. Filesystem grabs
> >   EXT4_I(inode)->i_mmap_sem for writing and retries with different
> >   get_blocks() callback which will allocate blocks. That way we get proper
> >   exclusion for faults needing to allocate blocks.
> > 
> > This patch adds this logic to DAX, ext2, ext4 and XFS.
> 
> It's too ugly to live. It hacks around a special DAX optimisation in
> the fault code by adding special case locking to the filesystems,
> and adds a significant new locking constraint to the page fault
> path.
> 
> If the write fault first goes through the read fault path and loads
> the hole, this race condition simply does not exist. I'd suggest
> that we get rid of the DAX optimisation that skips read fault
> processing on write fault so that this problem simply goes away.
> Yes, it means write faults on holes will be a little slower (which,
> quite frankly, I don't care at all about), but it means we don't
> need to hack special cases into code that should not have to care
> about various different types of page fault races. Correctness
> first, speed later.
> 
> FWIW, this also means we can get rid of the hacks in the filesystem
> code where we have to handle write faults through the ->fault
> handler rather than the ->page_mkwrite handler.

So I don't mind doing the read fault first. But as I wrote above, I don't
think it solves all the issues. The rules I wanted to introduce are:

1) A fault requiring block allocation needs an exclusive lock held over the
whole fault.

2) A fault not requiring allocation can run with a shared lock held over the
whole fault.

So we can certainly grab i_mmap_sem exclusively whenever we hit a write
fault and it would be correct. This looks like a quite clean design to me.
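
In code, the unconditional variant would look something like this (a
hypothetical sketch reusing ext2's dax_sem and the __dax_fault() signature
from this patch; pagefault accounting omitted):

static int dax_fault_excl_write(struct vm_area_struct *vma,
                                struct vm_fault *vmf)
{
        struct ext2_inode_info *ei = EXT2_I(file_inode(vma->vm_file));
        bool write = vmf->flags & FAULT_FLAG_WRITE;
        int ret;

        /*
         * Any write fault takes the lock exclusively, so block
         * allocation cannot race with a concurrent dax_load_hole().
         */
        if (write)
                down_write(&ei->dax_sem);
        else
                down_read(&ei->dax_sem);

        ret = __dax_fault(vma, vmf, ext2_get_block, NULL, write);

        if (write)
                up_write(&ei->dax_sem);
        else
                up_read(&ei->dax_sem);

        return ret;
}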

As a performance optimization (which is up to each filesystem) we can try to
satisfy the write fault without allocation while holding i_mmap_sem in shared
mode and, if that fails, grab the lock exclusively and retry. Still don't
like it?

								Honza
Matthew Wilcox Jan. 25, 2016, 8:46 p.m. UTC | #4
On Mon, Jan 25, 2016 at 09:01:07AM +1100, Dave Chinner wrote:
> On Fri, Jan 22, 2016 at 04:06:11PM -0700, Ross Zwisler wrote:
> > With the current DAX code the following race exists:
> > 
> > Process 1                	Process 2
> > ---------			---------
> > 
> > __dax_fault() - read file f, index 0
> >   get_block() -> returns hole
> >                              	__dax_fault() - write file f, index 0
> >                                   get_block() -> allocates blocks
> >                                   dax_insert_mapping()
> >   dax_load_hole()
> >   *data corruption*
> > 
> > An analogous race exists between __dax_fault() loading a hole and
> > __dax_pmd_fault() allocating a PMD DAX page and trying to insert it, and
> > that race also ends in data corruption.
> 
> Ok, so why doesn't this problem exist for the normal page cache
> insertion case with concurrent read vs write faults?  It's because
> > the write fault first does a read fault and so the write
> fault always has a page in the radix tree for the get_block call
> that allocates the extents, right?

No, it's because allocation of blocks is separated from allocation of
struct page.

> And DAX has an optimisation in the page fault part where it skips
> the read fault part of the write fault?  And so essentially the DAX
> write fault is missing the object (page lock of page in the radix
> tree) that the non-DAX write fault uses to avoid this problem?
>
> What happens if we get rid of that DAX write fault optimisation that
> skips the initial read fault? The write fault will always run on a
> mapping that has a hole loaded, right?, so the race between
> dax_load_hole() and dax_insert_mapping() goes away, because nothing
> will be calling dax_load_hole() once the write fault is allocating
> blocks....

So in your proposal, we'd look in the radix tree, find nothing,
call get_block(..., 0).  If we get something back, we can insert it.
If we hit a hole, we allocate a struct page, put it in the radix tree
and return to user space.  If that was a write fault after all, it'll
come back to us through the ->page_mkwrite handler where we can take the
page lock on the allocated struct page, then call down to DAX which calls
back through get_block to allocate?  Then DAX kicks the struct page out
of the page cache and frees it.

That seems to work to me.  And we can get rid of pfn_mkwrite at the same
time which seems like a win to me.

Jan Kara Jan. 26, 2016, 8:46 a.m. UTC | #5
On Mon 25-01-16 15:46:36, Matthew Wilcox wrote:
> On Mon, Jan 25, 2016 at 09:01:07AM +1100, Dave Chinner wrote:
> > On Fri, Jan 22, 2016 at 04:06:11PM -0700, Ross Zwisler wrote:
> > > With the current DAX code the following race exists:
> > > 
> > > Process 1                	Process 2
> > > ---------			---------
> > > 
> > > __dax_fault() - read file f, index 0
> > >   get_block() -> returns hole
> > >                              	__dax_fault() - write file f, index 0
> > >                                   get_block() -> allocates blocks
> > >                                   dax_insert_mapping()
> > >   dax_load_hole()
> > >   *data corruption*
> > > 
> > > An analogous race exists between __dax_fault() loading a hole and
> > > __dax_pmd_fault() allocating a PMD DAX page and trying to insert it, and
> > > that race also ends in data corruption.
> > 
> > Ok, so why doesn't this problem exist for the normal page cache
> > insertion case with concurrent read vs write faults?  It's because
> > the write fault first does a read fault and so the write
> > fault always has a page in the radix tree for the get_block call
> > that allocates the extents, right?
> 
> No, it's because allocation of blocks is separated from allocation of
> struct page.
> 
> > And DAX has an optimisation in the page fault part where it skips
> > the read fault part of the write fault?  And so essentially the DAX
> > write fault is missing the object (page lock of page in the radix
> > tree) that the non-DAX write fault uses to avoid this problem?
> >
> > What happens if we get rid of that DAX write fault optimisation that
> > skips the initial read fault? The write fault will always run on a
> > mapping that has a hole loaded, right?, so the race between
> > dax_load_hole() and dax_insert_mapping() goes away, because nothing
> > will be calling dax_load_hole() once the write fault is allocating
> > blocks....
> 
> So in your proposal, we'd look in the radix tree, find nothing,
> call get_block(..., 0).  If we get something back, we can insert it.
> If we hit a hole, we allocate a struct page, put it in the radix tree
> and return to user space.  If that was a write fault after all, it'll
> come back to us through the ->page_mkwrite handler where we can take the
> page lock on the allocated struct page, then call down to DAX which calls
> back through get_block to allocate?  Then DAX kicks the struct page out
> of the page cache and frees it.
> 
> That seems to work to me.  And we can get rid of pfn_mkwrite at the same
> time which seems like a win to me.

Getting rid of pfn_mkwrite() would be nice, I agree. But the above scheme
still has issues when PMD pages come into play. Or maybe to start from the
beginning: How would you like PMD faults to work? Because there is no
obvious 'struct page' to protect the allocation of 2MB worth of blocks. You
could resort to tricks similar to what transparent huge pages do (compound
pages), but then the cure is IMHO worse than the disease.

There is one option: no allocation of blocks for PMD faults. That would
make the code much simpler and solve the races, but I'm not sure whether
that is really an acceptable loss of functionality...

								Honza
Matthew Wilcox Jan. 26, 2016, 12:48 p.m. UTC | #6
On Mon, Jan 25, 2016 at 02:59:21PM +0100, Jan Kara wrote:
> On Mon 25-01-16 09:01:07, Dave Chinner wrote:
> > What happens if we get rid of that DAX write fault optimisation that
> > skips the initial read fault? The write fault will always run on a
> > mapping that has a hole loaded, right?, so the race between
> > dax_load_hole() and dax_insert_mapping() goes away, because nothing
> > will be calling dax_load_hole() once the write fault is allocating
> > blocks....
> 
> So frankly I don't like mixing page locks into the DAX fault locking.
> Also your scheme would require more tricks to deal with races between PMD
> write faults and PTE read faults, since you don't want to require
> 2MB worth of hole-pages to be able to do a PMD write fault. Transparent
> huge pages deal with this issue using compound pages but I'd like to avoid
> that horror in the DAX path...

I *think* that what Dave's proposing (and if he isn't, I'm proposing it
for him) is that the filesystem takes its allocation lock shared during
the ->fault handler, then in the ->page_mkwrite handler, it knows that an
allocation is coming, so it takes its allocation lock in exclusive mode.

So read vs write faults won't be able to race because the allocation lock
will prevent it.
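
Roughly, that split would look like the following (a hypothetical sketch;
fs_allocation_lock() and fs_get_block() are placeholders for whichever
rwsem and get_block callback a given filesystem uses):

/* ->fault: shared lock; may map existing blocks or load a hole page */
static int fs_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
{
        struct rw_semaphore *sem = fs_allocation_lock(file_inode(vma->vm_file));
        int ret;

        down_read(sem);
        ret = __dax_fault(vma, vmf, fs_get_block, NULL, false);
        up_read(sem);
        return ret;
}

/* ->page_mkwrite: an allocation is coming, so take the lock exclusive */
static int fs_dax_mkwrite(struct vm_area_struct *vma, struct vm_fault *vmf)
{
        struct rw_semaphore *sem = fs_allocation_lock(file_inode(vma->vm_file));
        int ret;

        down_write(sem);
        ret = __dax_mkwrite(vma, vmf, fs_get_block, NULL, true);
        up_write(sem);
        return ret;
}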

Jan Kara Jan. 26, 2016, 1:05 p.m. UTC | #7
On Tue 26-01-16 07:48:12, Matthew Wilcox wrote:
> On Mon, Jan 25, 2016 at 02:59:21PM +0100, Jan Kara wrote:
> > On Mon 25-01-16 09:01:07, Dave Chinner wrote:
> > > What happens if we get rid of that DAX write fault optimisation that
> > > skips the initial read fault? The write fault will always run on a
> > > mapping that has a hole loaded, right?, so the race between
> > > dax_load_hole() and dax_insert_mapping() goes away, because nothing
> > > will be calling dax_load_hole() once the write fault is allocating
> > > blocks....
> > 
> > So frankly I don't like mixing page locks into the DAX fault locking.
> > Also your scheme would require more tricks to deal with races between PMD
> > write faults and PTE read faults, since you don't want to require
> > 2MB worth of hole-pages to be able to do a PMD write fault. Transparent
> > huge pages deal with this issue using compound pages but I'd like to avoid
> > that horror in the DAX path...
> 
> I *think* that what Dave's proposing (and if he isn't, I'm proposing it
> for him) is that the filesystem takes its allocation lock shared during
> the ->fault handler, then in the ->page_mkwrite handler, it knows that an
> allocation is coming, so it takes its allocation lock in exclusive mode.
> 
> So read vs write faults won't be able to race because the allocation lock
> will prevent it.

So this is a correct and clean design, but we will take the lock in
exclusive mode (and thus hurt scalability) for every write fault, not just
for the ones allocating blocks. And once we take the exclusive lock for
write faults, there's no more need to have the hole page instantiated - we
can still do it for simplicity but it's no longer necessary to avoid data
corruption.

								Honza
Matthew Wilcox Jan. 26, 2016, 2:47 p.m. UTC | #8
On Tue, Jan 26, 2016 at 02:05:21PM +0100, Jan Kara wrote:
> On Tue 26-01-16 07:48:12, Matthew Wilcox wrote:
> > I *think* that what Dave's proposing (and if he isn't, I'm proposing it
> > for him) is that the filesystem takes its allocation lock shared during
> > the ->fault handler, then in the ->page_mkwrite handler, it knows that an
> > allocation is coming, so it takes its allocation lock in exclusive mode.
> > 
> > So read vs write faults won't be able to race because the allocation lock
> > will prevent it.
> 
> So this is a correct and clean design, but we will take the lock in
> exclusive mode (and thus hurt scalability) for every write fault, not just
> for the ones allocating blocks. And once we take the exclusive lock for
> write faults, there's no more need to have the hole page instantiated - we
> can still do it for simplicity but it's no longer necessary to avoid data
> corruption.

In my mind we take it only for allocating writes, because we also include
the patch to insert PFNs with the writable bit set in the dax_fault
handler if the page fault was for writes.

Although that only works when the *first* fault is a write ... if we
read a page and then write the same page, we will indeed take the lock
in exclusive mode.  I think that's fixable too -- in the page_mkwrite
handler, take the lock in exclusive mode only if there's a page in the
radix tree.  I'll take a look at that optimisation after doing the first
couple of steps.

Patch

diff --git a/fs/block_dev.c b/fs/block_dev.c
index 303b7cd..775f1b0 100644
--- a/fs/block_dev.c
+++ b/fs/block_dev.c
@@ -1733,13 +1733,28 @@  static const struct address_space_operations def_blk_aops = {
  */
 static int blkdev_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
 {
-	return __dax_fault(vma, vmf, blkdev_get_block, NULL);
+	int ret;
+
+	ret = __dax_fault(vma, vmf, blkdev_get_block, NULL, false);
+
+	if (WARN_ON_ONCE(ret == -EAGAIN))
+		ret = VM_FAULT_SIGBUS;
+
+	return ret;
 }
 
 static int blkdev_dax_pmd_fault(struct vm_area_struct *vma, unsigned long addr,
 		pmd_t *pmd, unsigned int flags)
 {
-	return __dax_pmd_fault(vma, addr, pmd, flags, blkdev_get_block, NULL);
+	int ret;
+
+	ret = __dax_pmd_fault(vma, addr, pmd, flags, blkdev_get_block, NULL,
+			false);
+
+	if (WARN_ON_ONCE(ret == -EAGAIN))
+		ret = VM_FAULT_SIGBUS;
+
+	return ret;
 }
 
 static void blkdev_vm_open(struct vm_area_struct *vma)
diff --git a/fs/dax.c b/fs/dax.c
index 206650f..7a927eb 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -582,13 +582,19 @@  static int dax_insert_mapping(struct inode *inode, struct buffer_head *bh,
  *	extent mappings from @get_block, but it is optional for reads as
  *	dax_insert_mapping() will always zero unwritten blocks. If the fs does
  *	not support unwritten extents, the it should pass NULL.
+ * @alloc_ok: True if our caller is holding a lock that isolates us from other
+ *	DAX faults on the same inode.  This allows us to allocate new storage
+ *	with get_block() and not have to worry about races with other fault
+ *	handlers.  If this is unset and we need to allocate storage we will
+ *	return -EAGAIN to ask our caller to retry with the proper locking.
  *
  * When a page fault occurs, filesystems may call this helper in their
  * fault handler for DAX files. __dax_fault() assumes the caller has done all
  * the necessary locking for the page fault to proceed successfully.
  */
 int __dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
-			get_block_t get_block, dax_iodone_t complete_unwritten)
+			get_block_t get_block, dax_iodone_t complete_unwritten,
+			bool alloc_ok)
 {
 	struct file *file = vma->vm_file;
 	struct address_space *mapping = file->f_mapping;
@@ -642,6 +648,9 @@  int __dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 
 	if (!buffer_mapped(&bh) && !buffer_unwritten(&bh) && !vmf->cow_page) {
 		if (vmf->flags & FAULT_FLAG_WRITE) {
+			if (!alloc_ok)
+				return -EAGAIN;
+
 			error = get_block(inode, block, &bh, 1);
 			count_vm_event(PGMAJFAULT);
 			mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
@@ -745,7 +754,7 @@  int dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 		sb_start_pagefault(sb);
 		file_update_time(vma->vm_file);
 	}
-	result = __dax_fault(vma, vmf, get_block, complete_unwritten);
+	result = __dax_fault(vma, vmf, get_block, complete_unwritten, false);
 	if (vmf->flags & FAULT_FLAG_WRITE)
 		sb_end_pagefault(sb);
 
@@ -780,7 +789,7 @@  static void __dax_dbg(struct buffer_head *bh, unsigned long address,
 
 int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
 		pmd_t *pmd, unsigned int flags, get_block_t get_block,
-		dax_iodone_t complete_unwritten)
+		dax_iodone_t complete_unwritten, bool alloc_ok)
 {
 	struct file *file = vma->vm_file;
 	struct address_space *mapping = file->f_mapping;
@@ -836,6 +845,9 @@  int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
 		return VM_FAULT_SIGBUS;
 
 	if (!buffer_mapped(&bh) && write) {
+		if (!alloc_ok)
+			return -EAGAIN;
+
 		if (get_block(inode, block, &bh, 1) != 0)
 			return VM_FAULT_SIGBUS;
 		alloc = true;
@@ -1017,7 +1029,7 @@  int dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
 		file_update_time(vma->vm_file);
 	}
 	result = __dax_pmd_fault(vma, address, pmd, flags, get_block,
-				complete_unwritten);
+				complete_unwritten, false);
 	if (flags & FAULT_FLAG_WRITE)
 		sb_end_pagefault(sb);
 
diff --git a/fs/ext2/file.c b/fs/ext2/file.c
index 2c88d68..1106a9e 100644
--- a/fs/ext2/file.c
+++ b/fs/ext2/file.c
@@ -49,11 +49,17 @@  static int ext2_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
 		sb_start_pagefault(inode->i_sb);
 		file_update_time(vma->vm_file);
 	}
+
 	down_read(&ei->dax_sem);
+	ret = __dax_fault(vma, vmf, ext2_get_block, NULL, false);
+	up_read(&ei->dax_sem);
 
-	ret = __dax_fault(vma, vmf, ext2_get_block, NULL);
+	if (ret == -EAGAIN) {
+		down_write(&ei->dax_sem);
+		ret = __dax_fault(vma, vmf, ext2_get_block, NULL, true);
+		up_write(&ei->dax_sem);
+	}
 
-	up_read(&ei->dax_sem);
 	if (vmf->flags & FAULT_FLAG_WRITE)
 		sb_end_pagefault(inode->i_sb);
 	return ret;
@@ -70,33 +76,24 @@  static int ext2_dax_pmd_fault(struct vm_area_struct *vma, unsigned long addr,
 		sb_start_pagefault(inode->i_sb);
 		file_update_time(vma->vm_file);
 	}
+
 	down_read(&ei->dax_sem);
+	ret = __dax_pmd_fault(vma, addr, pmd, flags, ext2_get_block, NULL,
+			false);
+	up_read(&ei->dax_sem);
 
-	ret = __dax_pmd_fault(vma, addr, pmd, flags, ext2_get_block, NULL);
+	if (ret == -EAGAIN) {
+		down_write(&ei->dax_sem);
+		ret = __dax_pmd_fault(vma, addr, pmd, flags, ext2_get_block,
+				NULL, true);
+		up_write(&ei->dax_sem);
+	}
 
-	up_read(&ei->dax_sem);
 	if (flags & FAULT_FLAG_WRITE)
 		sb_end_pagefault(inode->i_sb);
 	return ret;
 }
 
-static int ext2_dax_mkwrite(struct vm_area_struct *vma, struct vm_fault *vmf)
-{
-	struct inode *inode = file_inode(vma->vm_file);
-	struct ext2_inode_info *ei = EXT2_I(inode);
-	int ret;
-
-	sb_start_pagefault(inode->i_sb);
-	file_update_time(vma->vm_file);
-	down_read(&ei->dax_sem);
-
-	ret = __dax_mkwrite(vma, vmf, ext2_get_block, NULL);
-
-	up_read(&ei->dax_sem);
-	sb_end_pagefault(inode->i_sb);
-	return ret;
-}
-
 static int ext2_dax_pfn_mkwrite(struct vm_area_struct *vma,
 		struct vm_fault *vmf)
 {
@@ -124,7 +121,7 @@  static int ext2_dax_pfn_mkwrite(struct vm_area_struct *vma,
 static const struct vm_operations_struct ext2_dax_vm_ops = {
 	.fault		= ext2_dax_fault,
 	.pmd_fault	= ext2_dax_pmd_fault,
-	.page_mkwrite	= ext2_dax_mkwrite,
+	.page_mkwrite	= ext2_dax_fault,
 	.pfn_mkwrite	= ext2_dax_pfn_mkwrite,
 };
 
diff --git a/fs/ext4/file.c b/fs/ext4/file.c
index fa899c9..abddc8a 100644
--- a/fs/ext4/file.c
+++ b/fs/ext4/file.c
@@ -204,24 +204,30 @@  static int ext4_dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
 	if (write) {
 		sb_start_pagefault(sb);
 		file_update_time(vma->vm_file);
-		down_read(&EXT4_I(inode)->i_mmap_sem);
-		handle = ext4_journal_start_sb(sb, EXT4_HT_WRITE_PAGE,
-						EXT4_DATA_TRANS_BLOCKS(sb));
-	} else
-		down_read(&EXT4_I(inode)->i_mmap_sem);
+	}
 
-	if (IS_ERR(handle))
-		result = VM_FAULT_SIGBUS;
-	else
-		result = __dax_fault(vma, vmf, ext4_dax_mmap_get_block, NULL);
+	down_read(&EXT4_I(inode)->i_mmap_sem);
+	result = __dax_fault(vma, vmf, ext4_dax_mmap_get_block, NULL,
+			false);
+	up_read(&EXT4_I(inode)->i_mmap_sem);
 
-	if (write) {
-		if (!IS_ERR(handle))
+	if (result == -EAGAIN) {
+		down_write(&EXT4_I(inode)->i_mmap_sem);
+		handle = ext4_journal_start_sb(sb, EXT4_HT_WRITE_PAGE,
+				EXT4_DATA_TRANS_BLOCKS(sb));
+
+		if (IS_ERR(handle))
+			result = VM_FAULT_SIGBUS;
+		else {
+			result = __dax_fault(vma, vmf,
+					ext4_dax_mmap_get_block, NULL, true);
 			ext4_journal_stop(handle);
-		up_read(&EXT4_I(inode)->i_mmap_sem);
+		}
+		up_write(&EXT4_I(inode)->i_mmap_sem);
+	}
+
+	if (write)
 		sb_end_pagefault(sb);
-	} else
-		up_read(&EXT4_I(inode)->i_mmap_sem);
 
 	return result;
 }
@@ -238,47 +244,37 @@  static int ext4_dax_pmd_fault(struct vm_area_struct *vma, unsigned long addr,
 	if (write) {
 		sb_start_pagefault(sb);
 		file_update_time(vma->vm_file);
-		down_read(&EXT4_I(inode)->i_mmap_sem);
+	}
+
+	down_read(&EXT4_I(inode)->i_mmap_sem);
+	result = __dax_pmd_fault(vma, addr, pmd, flags,
+			ext4_dax_mmap_get_block, NULL, false);
+	up_read(&EXT4_I(inode)->i_mmap_sem);
+
+	if (result == -EAGAIN) {
+		down_write(&EXT4_I(inode)->i_mmap_sem);
 		handle = ext4_journal_start_sb(sb, EXT4_HT_WRITE_PAGE,
 				ext4_chunk_trans_blocks(inode,
 							PMD_SIZE / PAGE_SIZE));
-	} else
-		down_read(&EXT4_I(inode)->i_mmap_sem);
 
-	if (IS_ERR(handle))
-		result = VM_FAULT_SIGBUS;
-	else
-		result = __dax_pmd_fault(vma, addr, pmd, flags,
-				ext4_dax_mmap_get_block, NULL);
-
-	if (write) {
-		if (!IS_ERR(handle))
+		if (IS_ERR(handle))
+			result = VM_FAULT_SIGBUS;
+		else {
+			result = __dax_pmd_fault(vma, addr, pmd, flags,
+					ext4_dax_mmap_get_block, NULL, true);
 			ext4_journal_stop(handle);
-		up_read(&EXT4_I(inode)->i_mmap_sem);
+		}
+		up_write(&EXT4_I(inode)->i_mmap_sem);
+	}
+
+	if (write)
 		sb_end_pagefault(sb);
-	} else
-		up_read(&EXT4_I(inode)->i_mmap_sem);
 
 	return result;
 }
 
-static int ext4_dax_mkwrite(struct vm_area_struct *vma, struct vm_fault *vmf)
-{
-	int err;
-	struct inode *inode = file_inode(vma->vm_file);
-
-	sb_start_pagefault(inode->i_sb);
-	file_update_time(vma->vm_file);
-	down_read(&EXT4_I(inode)->i_mmap_sem);
-	err = __dax_mkwrite(vma, vmf, ext4_dax_mmap_get_block, NULL);
-	up_read(&EXT4_I(inode)->i_mmap_sem);
-	sb_end_pagefault(inode->i_sb);
-
-	return err;
-}
-
 /*
- * Handle write fault for VM_MIXEDMAP mappings. Similarly to ext4_dax_mkwrite()
+ * Handle write fault for VM_MIXEDMAP mappings. Similarly to ext4_dax_fault()
  * handler we check for races agaist truncate. Note that since we cycle through
  * i_mmap_sem, we are sure that also any hole punching that began before we
  * were called is finished by now and so if it included part of the file we
@@ -311,7 +307,7 @@  static int ext4_dax_pfn_mkwrite(struct vm_area_struct *vma,
 static const struct vm_operations_struct ext4_dax_vm_ops = {
 	.fault		= ext4_dax_fault,
 	.pmd_fault	= ext4_dax_pmd_fault,
-	.page_mkwrite	= ext4_dax_mkwrite,
+	.page_mkwrite	= ext4_dax_fault,
 	.pfn_mkwrite	= ext4_dax_pfn_mkwrite,
 };
 #else
diff --git a/fs/xfs/xfs_file.c b/fs/xfs/xfs_file.c
index 55e16e2..81edbd4 100644
--- a/fs/xfs/xfs_file.c
+++ b/fs/xfs/xfs_file.c
@@ -1523,16 +1523,26 @@  xfs_filemap_page_mkwrite(
 
 	sb_start_pagefault(inode->i_sb);
 	file_update_time(vma->vm_file);
-	xfs_ilock(XFS_I(inode), XFS_MMAPLOCK_SHARED);
 
 	if (IS_DAX(inode)) {
-		ret = __dax_mkwrite(vma, vmf, xfs_get_blocks_dax_fault, NULL);
+		xfs_ilock(XFS_I(inode), XFS_MMAPLOCK_SHARED);
+		ret = __dax_mkwrite(vma, vmf, xfs_get_blocks_dax_fault, NULL,
+				false);
+		xfs_iunlock(XFS_I(inode), XFS_MMAPLOCK_SHARED);
+
+		if (ret == -EAGAIN) {
+			xfs_ilock(XFS_I(inode), XFS_MMAPLOCK_EXCL);
+			ret = __dax_mkwrite(vma, vmf,
+					xfs_get_blocks_dax_fault, NULL, true);
+			xfs_iunlock(XFS_I(inode), XFS_MMAPLOCK_EXCL);
+		}
 	} else {
+		xfs_ilock(XFS_I(inode), XFS_MMAPLOCK_SHARED);
 		ret = block_page_mkwrite(vma, vmf, xfs_get_blocks);
 		ret = block_page_mkwrite_return(ret);
+		xfs_iunlock(XFS_I(inode), XFS_MMAPLOCK_SHARED);
 	}
 
-	xfs_iunlock(XFS_I(inode), XFS_MMAPLOCK_SHARED);
 	sb_end_pagefault(inode->i_sb);
 
 	return ret;
@@ -1560,7 +1570,8 @@  xfs_filemap_fault(
 		 * changes to xfs_get_blocks_direct() to map unwritten extent
 		 * ioend for conversion on read-only mappings.
 		 */
-		ret = __dax_fault(vma, vmf, xfs_get_blocks_dax_fault, NULL);
+		ret = __dax_fault(vma, vmf, xfs_get_blocks_dax_fault, NULL,
+				false);
 	} else
 		ret = filemap_fault(vma, vmf);
 	xfs_iunlock(XFS_I(inode), XFS_MMAPLOCK_SHARED);
@@ -1598,9 +1609,16 @@  xfs_filemap_pmd_fault(
 
 	xfs_ilock(XFS_I(inode), XFS_MMAPLOCK_SHARED);
 	ret = __dax_pmd_fault(vma, addr, pmd, flags, xfs_get_blocks_dax_fault,
-			      NULL);
+			      NULL, false);
 	xfs_iunlock(XFS_I(inode), XFS_MMAPLOCK_SHARED);
 
+	if (ret == -EAGAIN) {
+		xfs_ilock(XFS_I(inode), XFS_MMAPLOCK_EXCL);
+		ret = __dax_pmd_fault(vma, addr, pmd, flags,
+				xfs_get_blocks_dax_fault, NULL, true);
+		xfs_iunlock(XFS_I(inode), XFS_MMAPLOCK_EXCL);
+	}
+
 	if (flags & FAULT_FLAG_WRITE)
 		sb_end_pagefault(inode->i_sb);
 
diff --git a/include/linux/dax.h b/include/linux/dax.h
index 8204c3d..783a2b6 100644
--- a/include/linux/dax.h
+++ b/include/linux/dax.h
@@ -13,12 +13,13 @@  int dax_truncate_page(struct inode *, loff_t from, get_block_t);
 int dax_fault(struct vm_area_struct *, struct vm_fault *, get_block_t,
 		dax_iodone_t);
 int __dax_fault(struct vm_area_struct *, struct vm_fault *, get_block_t,
-		dax_iodone_t);
+		dax_iodone_t, bool alloc_ok);
 #ifdef CONFIG_TRANSPARENT_HUGEPAGE
 int dax_pmd_fault(struct vm_area_struct *, unsigned long addr, pmd_t *,
 				unsigned int flags, get_block_t, dax_iodone_t);
 int __dax_pmd_fault(struct vm_area_struct *, unsigned long addr, pmd_t *,
-				unsigned int flags, get_block_t, dax_iodone_t);
+				unsigned int flags, get_block_t, dax_iodone_t,
+				bool alloc_ok);
 #else
 static inline int dax_pmd_fault(struct vm_area_struct *vma, unsigned long addr,
 				pmd_t *pmd, unsigned int flags, get_block_t gb,
@@ -30,7 +31,8 @@  static inline int dax_pmd_fault(struct vm_area_struct *vma, unsigned long addr,
 #endif
 int dax_pfn_mkwrite(struct vm_area_struct *, struct vm_fault *);
 #define dax_mkwrite(vma, vmf, gb, iod)		dax_fault(vma, vmf, gb, iod)
-#define __dax_mkwrite(vma, vmf, gb, iod)	__dax_fault(vma, vmf, gb, iod)
+#define __dax_mkwrite(vma, vmf, gb, iod, alloc)  \
+	__dax_fault(vma, vmf, gb, iod, alloc)
 
 static inline bool vma_is_dax(struct vm_area_struct *vma)
 {