diff --git a/drivers/pci/pci.c b/drivers/pci/pci.c
index 5485883..c841aa6 100644
--- a/drivers/pci/pci.c
+++ b/drivers/pci/pci.c
@@ -2016,13 +2016,14 @@ void pci_free_cap_save_buffers(struct pci_dev *dev)
 void pci_enable_ari(struct pci_dev *dev)
 {
        u32 cap;
+       bool enable = true;
        struct pci_dev *bridge;

        if (pcie_ari_disabled || !pci_is_pcie(dev) || dev->devfn)
                return;

        if (!pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ARI))
-               return;
+               enable = false;

        bridge = dev->bus->self;
        if (!bridge)
@@ -2032,8 +2033,15 @@ void pci_enable_ari(struct pci_dev *dev)
        if (!(cap & PCI_EXP_DEVCAP2_ARI))
                return;

-       pcie_capability_set_word(bridge, PCI_EXP_DEVCTL2, PCI_EXP_DEVCTL2_ARI);
-       bridge->ari_enabled = 1;
+       if (enable) {
+               pcie_capability_set_word(bridge, PCI_EXP_DEVCTL2,
+                                        PCI_EXP_DEVCTL2_ARI);
+               bridge->ari_enabled = 1;
+       } else {
+               pcie_capability_clear_word(bridge, PCI_EXP_DEVCTL2,
+                                        PCI_EXP_DEVCTL2_ARI);
+               bridge->ari_enabled = 0;
+       }
 }

 /**
