powerpc/npu-dma.c: Fix deadlock in mmio_invalidate

Message ID 20180213031734.19831-1-alistair@popple.id.au
State Changes Requested
Headers show
Series
  • powerpc/npu-dma.c: Fix deadlock in mmio_invalidate
Related show

Commit Message

Alistair Popple Feb. 13, 2018, 3:17 a.m.
When sending TLB invalidates to the NPU we need to send extra flushes due
to a hardware issue. The original implementation would lock the all the
ATSD MMIO registers sequentially before unlocking and relocking each of
them sequentially to do the extra flush.

This introduced a deadlock as it is possible for one thread to hold one
ATSD register whilst waiting for another register to be freed while the
other thread is holding that register waiting for the one in the first
thread to be freed.

For example if there are two threads and two ATSD registers:

Thread A	Thread B
Acquire 1
Acquire 2
Release 1	Acquire 1
Wait 1		Wait 2

Both threads will be stuck waiting to acquire a register resulting in an
RCU stall warning or soft lockup.

This patch solves the deadlock by refactoring the code to ensure registers
are not released between flushes and to ensure all registers are either
acquired or released together and in order.

Fixes: bbd5ff50afff ("powerpc/powernv/npu-dma: Add explicit flush when sending an ATSD")
Signed-off-by: Alistair Popple <alistair@popple.id.au>
---

Michael,

This should probalby go to stable as well, although it's bigger than the 100
line limit mentioned in the stable kernel rules.

- Alistair

 arch/powerpc/platforms/powernv/npu-dma.c | 195 +++++++++++++++++--------------
 1 file changed, 109 insertions(+), 86 deletions(-)

Comments

Balbir Singh Feb. 13, 2018, 6:06 a.m. | #1
On Tue, 13 Feb 2018 14:17:34 +1100
Alistair Popple <alistair@popple.id.au> wrote:

> When sending TLB invalidates to the NPU we need to send extra flushes due
> to a hardware issue. The original implementation would lock the all the
> ATSD MMIO registers sequentially before unlocking and relocking each of
> them sequentially to do the extra flush.
> 
> This introduced a deadlock as it is possible for one thread to hold one
> ATSD register whilst waiting for another register to be freed while the
> other thread is holding that register waiting for the one in the first
> thread to be freed.
> 
> For example if there are two threads and two ATSD registers:
> 
> Thread A	Thread B
> Acquire 1
> Acquire 2
> Release 1	Acquire 1
> Wait 1		Wait 2
> 
> Both threads will be stuck waiting to acquire a register resulting in an
> RCU stall warning or soft lockup.
> 
> This patch solves the deadlock by refactoring the code to ensure registers
> are not released between flushes and to ensure all registers are either
> acquired or released together and in order.
> 
> Fixes: bbd5ff50afff ("powerpc/powernv/npu-dma: Add explicit flush when sending an ATSD")
> Signed-off-by: Alistair Popple <alistair@popple.id.au>
> ---
> 
> Michael,
> 
> This should probalby go to stable as well, although it's bigger than the 100
> line limit mentioned in the stable kernel rules.
> 
> - Alistair
> 
>  arch/powerpc/platforms/powernv/npu-dma.c | 195 +++++++++++++++++--------------
>  1 file changed, 109 insertions(+), 86 deletions(-)
> 
> diff --git a/arch/powerpc/platforms/powernv/npu-dma.c b/arch/powerpc/platforms/powernv/npu-dma.c
> index fb0a6dee9bce..5746b456dfa4 100644
> --- a/arch/powerpc/platforms/powernv/npu-dma.c
> +++ b/arch/powerpc/platforms/powernv/npu-dma.c
> @@ -408,6 +408,11 @@ struct npu_context {
>  	void *priv;
>  };
>  
> +struct mmio_atsd_reg {
> +	struct npu *npu;
> +	int reg;
> +};
> +

Is it just easier to move reg to inside of struct npu?

>  /*
>   * Find a free MMIO ATSD register and mark it in use. Return -ENOSPC
>   * if none are available.
> @@ -433,79 +438,83 @@ static void put_mmio_atsd_reg(struct npu *npu, int reg)
>  #define XTS_ATSD_AVA  1
>  #define XTS_ATSD_STAT 2
>  
> -static int mmio_launch_invalidate(struct npu *npu, unsigned long launch,
> -				unsigned long va)
> +static void mmio_launch_invalidate(struct mmio_atsd_reg *mmio_atsd_reg,
> +				unsigned long launch, unsigned long va)
>  {
> -	int mmio_atsd_reg;
> -
> -	do {
> -		mmio_atsd_reg = get_mmio_atsd_reg(npu);
> -		cpu_relax();
> -	} while (mmio_atsd_reg < 0);
> +	struct npu *npu = mmio_atsd_reg->npu;
> +	int reg = mmio_atsd_reg->reg;
>  
>  	__raw_writeq(cpu_to_be64(va),
> -		npu->mmio_atsd_regs[mmio_atsd_reg] + XTS_ATSD_AVA);
> +		npu->mmio_atsd_regs[reg] + XTS_ATSD_AVA);
>  	eieio();
> -	__raw_writeq(cpu_to_be64(launch), npu->mmio_atsd_regs[mmio_atsd_reg]);
> -
> -	return mmio_atsd_reg;
> +	__raw_writeq(cpu_to_be64(launch), npu->mmio_atsd_regs[reg]);
>  }
>  
> -static int mmio_invalidate_pid(struct npu *npu, unsigned long pid, bool flush)
> +static void mmio_invalidate_pid(struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS],
> +				unsigned long pid, bool flush)
>  {
> +	int i;
>  	unsigned long launch;
>  
> -	/* IS set to invalidate matching PID */
> -	launch = PPC_BIT(12);
> +	for (i = 0; i <= max_npu2_index; i++) {
> +		if (mmio_atsd_reg[i].reg < 0)
> +			continue;
>  
> -	/* PRS set to process-scoped */
> -	launch |= PPC_BIT(13);
> +		/* IS set to invalidate matching PID */
> +		launch = PPC_BIT(12);
>  
> -	/* AP */
> -	launch |= (u64) mmu_get_ap(mmu_virtual_psize) << PPC_BITLSHIFT(17);
> +		/* PRS set to process-scoped */
> +		launch |= PPC_BIT(13);
>  
> -	/* PID */
> -	launch |= pid << PPC_BITLSHIFT(38);
> +		/* AP */
> +		launch |= (u64)
> +			mmu_get_ap(mmu_virtual_psize) << PPC_BITLSHIFT(17);
>  
> -	/* No flush */
> -	launch |= !flush << PPC_BITLSHIFT(39);
> +		/* PID */
> +		launch |= pid << PPC_BITLSHIFT(38);
>  
> -	/* Invalidating the entire process doesn't use a va */
> -	return mmio_launch_invalidate(npu, launch, 0);
> +		/* No flush */
> +		launch |= !flush << PPC_BITLSHIFT(39);
> +
> +		/* Invalidating the entire process doesn't use a va */
> +		mmio_launch_invalidate(&mmio_atsd_reg[i], launch, 0);
> +	}
>  }
>  
> -static int mmio_invalidate_va(struct npu *npu, unsigned long va,
> -			unsigned long pid, bool flush)
> +static void mmio_invalidate_va(struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS],
> +			unsigned long va, unsigned long pid, bool flush)
>  {
> +	int i;
>  	unsigned long launch;
>  
> -	/* IS set to invalidate target VA */
> -	launch = 0;
> +	for (i = 0; i <= max_npu2_index; i++) {
> +		if (mmio_atsd_reg[i].reg < 0)
> +			continue;
> +
> +		/* IS set to invalidate target VA */
> +		launch = 0;
>  
> -	/* PRS set to process scoped */
> -	launch |= PPC_BIT(13);
> +		/* PRS set to process scoped */
> +		launch |= PPC_BIT(13);
>  
> -	/* AP */
> -	launch |= (u64) mmu_get_ap(mmu_virtual_psize) << PPC_BITLSHIFT(17);
> +		/* AP */
> +		launch |= (u64)
> +			mmu_get_ap(mmu_virtual_psize) << PPC_BITLSHIFT(17);
>  
> -	/* PID */
> -	launch |= pid << PPC_BITLSHIFT(38);
> +		/* PID */
> +		launch |= pid << PPC_BITLSHIFT(38);
>  
> -	/* No flush */
> -	launch |= !flush << PPC_BITLSHIFT(39);
> +		/* No flush */
> +		launch |= !flush << PPC_BITLSHIFT(39);
>  
> -	return mmio_launch_invalidate(npu, launch, va);
> +		mmio_launch_invalidate(&mmio_atsd_reg[i], launch, va);
> +	}
>  }
>  
>  #define mn_to_npu_context(x) container_of(x, struct npu_context, mn)
>  
> -struct mmio_atsd_reg {
> -	struct npu *npu;
> -	int reg;
> -};
> -
>  static void mmio_invalidate_wait(
> -	struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS], bool flush)
> +	struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS])
>  {
>  	struct npu *npu;
>  	int i, reg;
> @@ -520,16 +529,46 @@ static void mmio_invalidate_wait(
>  		reg = mmio_atsd_reg[i].reg;
>  		while (__raw_readq(npu->mmio_atsd_regs[reg] + XTS_ATSD_STAT))
>  			cpu_relax();
> +	}
> +}
>  
> -		put_mmio_atsd_reg(npu, reg);
> +static void acquire_atsd_reg(struct npu_context *npu_context,
> +			struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS])
> +{
> +	int i, j;
> +	struct npu *npu;
> +	struct pci_dev *npdev;
> +	struct pnv_phb *nphb;
>  
> -		/*
> -		 * The GPU requires two flush ATSDs to ensure all entries have
> -		 * been flushed. We use PID 0 as it will never be used for a
> -		 * process on the GPU.
> -		 */
> -		if (flush)
> -			mmio_invalidate_pid(npu, 0, true);
> +	for (i = 0; i <= max_npu2_index; i++) {
> +		mmio_atsd_reg[i].reg = -1;
> +		for (j = 0; j < NV_MAX_LINKS; j++) {

Is it safe to assume that npu_context->npdev will not change in this
loop? I guess it would need to be stronger than just this loop.

> +			npdev = npu_context->npdev[i][j];
> +			if (!npdev)
> +				continue;
> +
> +			nphb = pci_bus_to_host(npdev->bus)->private_data;
> +			npu = &nphb->npu;
> +			mmio_atsd_reg[i].npu = npu;
> +			mmio_atsd_reg[i].reg = get_mmio_atsd_reg(npu);
> +			while (mmio_atsd_reg[i].reg < 0) {
> +				mmio_atsd_reg[i].reg = get_mmio_atsd_reg(npu);
> +				cpu_relax();

A cond_resched() as well if we have too many tries?

Balbir
Alistair Popple Feb. 14, 2018, 3:23 a.m. | #2
> > +struct mmio_atsd_reg {
> > +	struct npu *npu;
> > +	int reg;
> > +};
> > +
> 
> Is it just easier to move reg to inside of struct npu?

I don't think so, struct npu is global to all npu contexts where as this is
specific to the given invalidation. We don't have enough registers to assign
each NPU context it's own dedicated register so I'm not sure it makes sense to
put it there either.

> > +static void acquire_atsd_reg(struct npu_context *npu_context,
> > +			struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS])
> > +{
> > +	int i, j;
> > +	struct npu *npu;
> > +	struct pci_dev *npdev;
> > +	struct pnv_phb *nphb;
> >  
> > -		/*
> > -		 * The GPU requires two flush ATSDs to ensure all entries have
> > -		 * been flushed. We use PID 0 as it will never be used for a
> > -		 * process on the GPU.
> > -		 */
> > -		if (flush)
> > -			mmio_invalidate_pid(npu, 0, true);
> > +	for (i = 0; i <= max_npu2_index; i++) {
> > +		mmio_atsd_reg[i].reg = -1;
> > +		for (j = 0; j < NV_MAX_LINKS; j++) {
> 
> Is it safe to assume that npu_context->npdev will not change in this
> loop? I guess it would need to be stronger than just this loop.

It is not safe to assume that npu_context->npdev won't change during this loop,
however I don't think it is a problem if it does as we only read each element
once during the invalidation.

There are two possibilities for how this could change. pnv_npu2_init_context()
will add a nvlink to the npdev which will result in the TLB invalidation being
sent to that GPU as well which should not be a problem.

pnv_npu2_destroy_context() will remove the the nvlink from npdev. If it happens
prior to this loop it should not be a problem (as the destruction will have
already invalidated the GPU TLB). If it happens after this loop it shouldn't be
a problem either (it will just result in an extra TLB invalidate being sent to
this GPU).

> > +			npdev = npu_context->npdev[i][j];
> > +			if (!npdev)
> > +				continue;
> > +
> > +			nphb = pci_bus_to_host(npdev->bus)->private_data;
> > +			npu = &nphb->npu;
> > +			mmio_atsd_reg[i].npu = npu;
> > +			mmio_atsd_reg[i].reg = get_mmio_atsd_reg(npu);
> > +			while (mmio_atsd_reg[i].reg < 0) {
> > +				mmio_atsd_reg[i].reg = get_mmio_atsd_reg(npu);
> > +				cpu_relax();
> 
> A cond_resched() as well if we have too many tries?

I don't think we can as the invalidate_range() function is called under the ptl
spin-lock and is not allowed to sleep (at least according to
include/linux/mmu_notifier.h).

- Alistair

> Balbir
>
Mark Hairgrove Feb. 16, 2018, 3:11 a.m. | #3
On Wed, 14 Feb 2018, Alistair Popple wrote:

> > > +struct mmio_atsd_reg {
> > > +	struct npu *npu;
> > > +	int reg;
> > > +};
> > > +
> > 
> > Is it just easier to move reg to inside of struct npu?
> 
> I don't think so, struct npu is global to all npu contexts where as this is
> specific to the given invalidation. We don't have enough registers to assign
> each NPU context it's own dedicated register so I'm not sure it makes sense to
> put it there either.
> 
> > > +static void acquire_atsd_reg(struct npu_context *npu_context,
> > > +			struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS])
> > > +{
> > > +	int i, j;
> > > +	struct npu *npu;
> > > +	struct pci_dev *npdev;
> > > +	struct pnv_phb *nphb;
> > >  
> > > -		/*
> > > -		 * The GPU requires two flush ATSDs to ensure all entries have
> > > -		 * been flushed. We use PID 0 as it will never be used for a
> > > -		 * process on the GPU.
> > > -		 */
> > > -		if (flush)
> > > -			mmio_invalidate_pid(npu, 0, true);
> > > +	for (i = 0; i <= max_npu2_index; i++) {
> > > +		mmio_atsd_reg[i].reg = -1;
> > > +		for (j = 0; j < NV_MAX_LINKS; j++) {
> > 
> > Is it safe to assume that npu_context->npdev will not change in this
> > loop? I guess it would need to be stronger than just this loop.
> 
> It is not safe to assume that npu_context->npdev won't change during this loop,
> however I don't think it is a problem if it does as we only read each element
> once during the invalidation.

Shouldn't that be enforced with READ_ONCE() then?

I assume that npdev->bus can't change until after the last
pnv_npu2_destroy_context() is called for an npu. In that case, the
mmu_notifier_unregister() in pnv_npu2_release_context() will block until
mmio_invalidate() is done using npdev. That seems safe enough, but a
comment somewhere about that would be useful.

> 
> There are two possibilities for how this could change. pnv_npu2_init_context()
> will add a nvlink to the npdev which will result in the TLB invalidation being
> sent to that GPU as well which should not be a problem.
> 
> pnv_npu2_destroy_context() will remove the the nvlink from npdev. If it happens
> prior to this loop it should not be a problem (as the destruction will have
> already invalidated the GPU TLB). If it happens after this loop it shouldn't be
> a problem either (it will just result in an extra TLB invalidate being sent to
> this GPU).
> 
> > > +			npdev = npu_context->npdev[i][j];
> > > +			if (!npdev)
> > > +				continue;
> > > +
> > > +			nphb = pci_bus_to_host(npdev->bus)->private_data;
> > > +			npu = &nphb->npu;
> > > +			mmio_atsd_reg[i].npu = npu;
> > > +			mmio_atsd_reg[i].reg = get_mmio_atsd_reg(npu);
> > > +			while (mmio_atsd_reg[i].reg < 0) {
> > > +				mmio_atsd_reg[i].reg = get_mmio_atsd_reg(npu);
> > > +				cpu_relax();
> > 
> > A cond_resched() as well if we have too many tries?
> 
> I don't think we can as the invalidate_range() function is called under the ptl
> spin-lock and is not allowed to sleep (at least according to
> include/linux/mmu_notifier.h).
> 
> - Alistair
> 
> > Balbir
> > 
> 
> 
>
Balbir Singh Feb. 19, 2018, 2:57 a.m. | #4
On Thu, 15 Feb 2018 19:11:19 -0800
Mark Hairgrove <mhairgrove@nvidia.com> wrote:

> On Wed, 14 Feb 2018, Alistair Popple wrote:
> 
> > > > +struct mmio_atsd_reg {
> > > > +	struct npu *npu;
> > > > +	int reg;
> > > > +};
> > > > +  
> > > 
> > > Is it just easier to move reg to inside of struct npu?  
> > 
> > I don't think so, struct npu is global to all npu contexts where as this is
> > specific to the given invalidation. We don't have enough registers to assign
> > each NPU context it's own dedicated register so I'm not sure it makes sense to
> > put it there either.

Fair enough, also discussed this offline with you.

> >   
> > > > +static void acquire_atsd_reg(struct npu_context *npu_context,
> > > > +			struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS])
> > > > +{
> > > > +	int i, j;
> > > > +	struct npu *npu;
> > > > +	struct pci_dev *npdev;
> > > > +	struct pnv_phb *nphb;
> > > >  
> > > > -		/*
> > > > -		 * The GPU requires two flush ATSDs to ensure all entries have
> > > > -		 * been flushed. We use PID 0 as it will never be used for a
> > > > -		 * process on the GPU.
> > > > -		 */
> > > > -		if (flush)
> > > > -			mmio_invalidate_pid(npu, 0, true);
> > > > +	for (i = 0; i <= max_npu2_index; i++) {
> > > > +		mmio_atsd_reg[i].reg = -1;
> > > > +		for (j = 0; j < NV_MAX_LINKS; j++) {  
> > > 
> > > Is it safe to assume that npu_context->npdev will not change in this
> > > loop? I guess it would need to be stronger than just this loop.  
> > 
> > It is not safe to assume that npu_context->npdev won't change during this loop,
> > however I don't think it is a problem if it does as we only read each element
> > once during the invalidation.  
> 
> Shouldn't that be enforced with READ_ONCE() then?

Good point, although I think the acquire_* function itself may be called
from a higher layer with the mmap_sem always held. I wonder if we need
barriers around get and put mmio_atsd_reg.


> 
> I assume that npdev->bus can't change until after the last
> pnv_npu2_destroy_context() is called for an npu. In that case, the
> mmu_notifier_unregister() in pnv_npu2_release_context() will block until
> mmio_invalidate() is done using npdev. That seems safe enough, but a
> comment somewhere about that would be useful.
> 
> > 
> > There are two possibilities for how this could change. pnv_npu2_init_context()
> > will add a nvlink to the npdev which will result in the TLB invalidation being
> > sent to that GPU as well which should not be a problem.
> > 
> > pnv_npu2_destroy_context() will remove the the nvlink from npdev. If it happens
> > prior to this loop it should not be a problem (as the destruction will have
> > already invalidated the GPU TLB). If it happens after this loop it shouldn't be
> > a problem either (it will just result in an extra TLB invalidate being sent to
> > this GPU).
> >   
> > > > +			npdev = npu_context->npdev[i][j];
> > > > +			if (!npdev)
> > > > +				continue;
> > > > +
> > > > +			nphb = pci_bus_to_host(npdev->bus)->private_data;
> > > > +			npu = &nphb->npu;
> > > > +			mmio_atsd_reg[i].npu = npu;
> > > > +			mmio_atsd_reg[i].reg = get_mmio_atsd_reg(npu);
> > > > +			while (mmio_atsd_reg[i].reg < 0) {
> > > > +				mmio_atsd_reg[i].reg = get_mmio_atsd_reg(npu);
> > > > +				cpu_relax();  
> > > 
> > > A cond_resched() as well if we have too many tries?  
> > 
> > I don't think we can as the invalidate_range() function is called under the ptl
> > spin-lock and is not allowed to sleep (at least according to
> > include/linux/mmu_notifier.h).

I double checked, It's the reverse

	 /*
          * If both of these callbacks cannot block, mmu_notifier_ops.flags
          * should have MMU_INVALIDATE_DOES_NOT_BLOCK set.
          */
         void (*invalidate_range_start)(struct mmu_notifier *mn,
                                        struct mm_struct *mm,
                                        unsigned long start, unsigned long end);
         void (*invalidate_range_end)(struct mmu_notifier *mn,
                                      struct mm_struct *mm,
                                      unsigned long start, unsigned long end);



> > 
> > - Alistair
> >   
> > > Balbir
> > >   
> > 
> > 
> >   

I think it looks good to me otherwise,

Balbir Singh.
Alistair Popple Feb. 19, 2018, 5:02 a.m. | #5
> > Shouldn't that be enforced with READ_ONCE() then?

Yep, I can add that.

> Good point, although I think the acquire_* function itself may be called
> from a higher layer with the mmap_sem always held. I wonder if we need
> barriers around get and put mmio_atsd_reg.

test_and_set_bit() should imply a memory barrier so I don't think we need one
there (and looking at the implementation there is one). clear_bit() might need
one though. For that I guess I could use clear_bit_unlock()? There is also a
matching test_and_set_bit_lock() so I will submit a v2 which uses those instead
given we are using these like a lock.

> > > I don't think we can as the invalidate_range() function is called under the ptl
> > > spin-lock and is not allowed to sleep (at least according to
> > > include/linux/mmu_notifier.h).
> 
> I double checked, It's the reverse
> 
> 	 /*
>           * If both of these callbacks cannot block, mmu_notifier_ops.flags
>           * should have MMU_INVALIDATE_DOES_NOT_BLOCK set.
>           */

Argh, that must have been merged during the current window. Thanks for pointing
out - I will submit a seperate patch to update the mmu_notifier_ops.flags to set
MMU_INVALIDATE_DOES_NOT_BLOCK.

- Alistair

>          void (*invalidate_range_start)(struct mmu_notifier *mn,
>                                         struct mm_struct *mm,
>                                         unsigned long start, unsigned long end);
>          void (*invalidate_range_end)(struct mmu_notifier *mn,
>                                       struct mm_struct *mm,
>                                       unsigned long start, unsigned long end);
> > > 
> > > - Alistair
> > >   
> > > > Balbir
> > > >   
> > > 
> > > 
> > >   
> 
> I think it looks good to me otherwise,
> 
> Balbir Singh.
>
Mark Hairgrove Feb. 20, 2018, 6:34 p.m. | #6
On Mon, 19 Feb 2018, Balbir Singh wrote:

> Good point, although I think the acquire_* function itself may be called
> from a higher layer with the mmap_sem always held. I wonder if we need
> barriers around get and put mmio_atsd_reg.

I agree with the need for memory barriers. FWIW, page tables can be
invalidated without mmap_sem being held, for example by
unmap_mapping_range.

Patch

diff --git a/arch/powerpc/platforms/powernv/npu-dma.c b/arch/powerpc/platforms/powernv/npu-dma.c
index fb0a6dee9bce..5746b456dfa4 100644
--- a/arch/powerpc/platforms/powernv/npu-dma.c
+++ b/arch/powerpc/platforms/powernv/npu-dma.c
@@ -408,6 +408,11 @@  struct npu_context {
 	void *priv;
 };
 
+struct mmio_atsd_reg {
+	struct npu *npu;
+	int reg;
+};
+
 /*
  * Find a free MMIO ATSD register and mark it in use. Return -ENOSPC
  * if none are available.
@@ -433,79 +438,83 @@  static void put_mmio_atsd_reg(struct npu *npu, int reg)
 #define XTS_ATSD_AVA  1
 #define XTS_ATSD_STAT 2
 
-static int mmio_launch_invalidate(struct npu *npu, unsigned long launch,
-				unsigned long va)
+static void mmio_launch_invalidate(struct mmio_atsd_reg *mmio_atsd_reg,
+				unsigned long launch, unsigned long va)
 {
-	int mmio_atsd_reg;
-
-	do {
-		mmio_atsd_reg = get_mmio_atsd_reg(npu);
-		cpu_relax();
-	} while (mmio_atsd_reg < 0);
+	struct npu *npu = mmio_atsd_reg->npu;
+	int reg = mmio_atsd_reg->reg;
 
 	__raw_writeq(cpu_to_be64(va),
-		npu->mmio_atsd_regs[mmio_atsd_reg] + XTS_ATSD_AVA);
+		npu->mmio_atsd_regs[reg] + XTS_ATSD_AVA);
 	eieio();
-	__raw_writeq(cpu_to_be64(launch), npu->mmio_atsd_regs[mmio_atsd_reg]);
-
-	return mmio_atsd_reg;
+	__raw_writeq(cpu_to_be64(launch), npu->mmio_atsd_regs[reg]);
 }
 
-static int mmio_invalidate_pid(struct npu *npu, unsigned long pid, bool flush)
+static void mmio_invalidate_pid(struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS],
+				unsigned long pid, bool flush)
 {
+	int i;
 	unsigned long launch;
 
-	/* IS set to invalidate matching PID */
-	launch = PPC_BIT(12);
+	for (i = 0; i <= max_npu2_index; i++) {
+		if (mmio_atsd_reg[i].reg < 0)
+			continue;
 
-	/* PRS set to process-scoped */
-	launch |= PPC_BIT(13);
+		/* IS set to invalidate matching PID */
+		launch = PPC_BIT(12);
 
-	/* AP */
-	launch |= (u64) mmu_get_ap(mmu_virtual_psize) << PPC_BITLSHIFT(17);
+		/* PRS set to process-scoped */
+		launch |= PPC_BIT(13);
 
-	/* PID */
-	launch |= pid << PPC_BITLSHIFT(38);
+		/* AP */
+		launch |= (u64)
+			mmu_get_ap(mmu_virtual_psize) << PPC_BITLSHIFT(17);
 
-	/* No flush */
-	launch |= !flush << PPC_BITLSHIFT(39);
+		/* PID */
+		launch |= pid << PPC_BITLSHIFT(38);
 
-	/* Invalidating the entire process doesn't use a va */
-	return mmio_launch_invalidate(npu, launch, 0);
+		/* No flush */
+		launch |= !flush << PPC_BITLSHIFT(39);
+
+		/* Invalidating the entire process doesn't use a va */
+		mmio_launch_invalidate(&mmio_atsd_reg[i], launch, 0);
+	}
 }
 
-static int mmio_invalidate_va(struct npu *npu, unsigned long va,
-			unsigned long pid, bool flush)
+static void mmio_invalidate_va(struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS],
+			unsigned long va, unsigned long pid, bool flush)
 {
+	int i;
 	unsigned long launch;
 
-	/* IS set to invalidate target VA */
-	launch = 0;
+	for (i = 0; i <= max_npu2_index; i++) {
+		if (mmio_atsd_reg[i].reg < 0)
+			continue;
+
+		/* IS set to invalidate target VA */
+		launch = 0;
 
-	/* PRS set to process scoped */
-	launch |= PPC_BIT(13);
+		/* PRS set to process scoped */
+		launch |= PPC_BIT(13);
 
-	/* AP */
-	launch |= (u64) mmu_get_ap(mmu_virtual_psize) << PPC_BITLSHIFT(17);
+		/* AP */
+		launch |= (u64)
+			mmu_get_ap(mmu_virtual_psize) << PPC_BITLSHIFT(17);
 
-	/* PID */
-	launch |= pid << PPC_BITLSHIFT(38);
+		/* PID */
+		launch |= pid << PPC_BITLSHIFT(38);
 
-	/* No flush */
-	launch |= !flush << PPC_BITLSHIFT(39);
+		/* No flush */
+		launch |= !flush << PPC_BITLSHIFT(39);
 
-	return mmio_launch_invalidate(npu, launch, va);
+		mmio_launch_invalidate(&mmio_atsd_reg[i], launch, va);
+	}
 }
 
 #define mn_to_npu_context(x) container_of(x, struct npu_context, mn)
 
-struct mmio_atsd_reg {
-	struct npu *npu;
-	int reg;
-};
-
 static void mmio_invalidate_wait(
-	struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS], bool flush)
+	struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS])
 {
 	struct npu *npu;
 	int i, reg;
@@ -520,16 +529,46 @@  static void mmio_invalidate_wait(
 		reg = mmio_atsd_reg[i].reg;
 		while (__raw_readq(npu->mmio_atsd_regs[reg] + XTS_ATSD_STAT))
 			cpu_relax();
+	}
+}
 
-		put_mmio_atsd_reg(npu, reg);
+static void acquire_atsd_reg(struct npu_context *npu_context,
+			struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS])
+{
+	int i, j;
+	struct npu *npu;
+	struct pci_dev *npdev;
+	struct pnv_phb *nphb;
 
-		/*
-		 * The GPU requires two flush ATSDs to ensure all entries have
-		 * been flushed. We use PID 0 as it will never be used for a
-		 * process on the GPU.
-		 */
-		if (flush)
-			mmio_invalidate_pid(npu, 0, true);
+	for (i = 0; i <= max_npu2_index; i++) {
+		mmio_atsd_reg[i].reg = -1;
+		for (j = 0; j < NV_MAX_LINKS; j++) {
+			npdev = npu_context->npdev[i][j];
+			if (!npdev)
+				continue;
+
+			nphb = pci_bus_to_host(npdev->bus)->private_data;
+			npu = &nphb->npu;
+			mmio_atsd_reg[i].npu = npu;
+			mmio_atsd_reg[i].reg = get_mmio_atsd_reg(npu);
+			while (mmio_atsd_reg[i].reg < 0) {
+				mmio_atsd_reg[i].reg = get_mmio_atsd_reg(npu);
+				cpu_relax();
+			}
+			break;
+		}
+	}
+}
+
+static void release_atsd_reg(struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS])
+{
+	int i;
+
+	for (i = 0; i <= max_npu2_index; i++) {
+		if (mmio_atsd_reg[i].reg < 0)
+			continue;
+
+		put_mmio_atsd_reg(mmio_atsd_reg[i].npu, mmio_atsd_reg[i].reg);
 	}
 }
 
@@ -540,10 +579,6 @@  static void mmio_invalidate_wait(
 static void mmio_invalidate(struct npu_context *npu_context, int va,
 			unsigned long address, bool flush)
 {
-	int i, j;
-	struct npu *npu;
-	struct pnv_phb *nphb;
-	struct pci_dev *npdev;
 	struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS];
 	unsigned long pid = npu_context->mm->context.id;
 
@@ -559,37 +594,25 @@  static void mmio_invalidate(struct npu_context *npu_context, int va,
 	 * Loop over all the NPUs this process is active on and launch
 	 * an invalidate.
 	 */
-	for (i = 0; i <= max_npu2_index; i++) {
-		mmio_atsd_reg[i].reg = -1;
-		for (j = 0; j < NV_MAX_LINKS; j++) {
-			npdev = npu_context->npdev[i][j];
-			if (!npdev)
-				continue;
-
-			nphb = pci_bus_to_host(npdev->bus)->private_data;
-			npu = &nphb->npu;
-			mmio_atsd_reg[i].npu = npu;
-
-			if (va)
-				mmio_atsd_reg[i].reg =
-					mmio_invalidate_va(npu, address, pid,
-							flush);
-			else
-				mmio_atsd_reg[i].reg =
-					mmio_invalidate_pid(npu, pid, flush);
-
-			/*
-			 * The NPU hardware forwards the shootdown to all GPUs
-			 * so we only have to launch one shootdown per NPU.
-			 */
-			break;
-		}
+	acquire_atsd_reg(npu_context, mmio_atsd_reg);
+	if (va)
+		mmio_invalidate_va(mmio_atsd_reg, address, pid, flush);
+	else
+		mmio_invalidate_pid(mmio_atsd_reg, pid, flush);
+
+	mmio_invalidate_wait(mmio_atsd_reg);
+	if (flush) {
+		/*
+		 * The GPU requires two flush ATSDs to ensure all entries have
+		 * been flushed. We use PID 0 as it will never be used for a
+		 * process on the GPU.
+		 */
+		mmio_invalidate_pid(mmio_atsd_reg, 0, true);
+		mmio_invalidate_wait(mmio_atsd_reg);
+		mmio_invalidate_pid(mmio_atsd_reg, 0, true);
+		mmio_invalidate_wait(mmio_atsd_reg);
 	}
-
-	mmio_invalidate_wait(mmio_atsd_reg, flush);
-	if (flush)
-		/* Wait for the flush to complete */
-		mmio_invalidate_wait(mmio_atsd_reg, false);
+	release_atsd_reg(mmio_atsd_reg);
 }
 
 static void pnv_npu2_mn_release(struct mmu_notifier *mn,