diff mbox series

[v13,03/10] virtio-iommu: Implement attach/detach command

Message ID 20200125171955.12825-4-eric.auger@redhat.com
State New
Headers show
Series VIRTIO-IOMMU device | expand

Commit Message

Eric Auger Jan. 25, 2020, 5:19 p.m. UTC
This patch implements the endpoint attach/detach to/from
a domain.

Domain and endpoint internal datatypes are introduced.
Both are stored in RB trees. The domain owns a list of
endpoints attached to it. Also helpers to get/put
end points and domains are introduced.

As for the IOMMU memory regions, a callback is called on
PCI bus enumeration that initializes for a given device
on the bus hierarchy an IOMMU memory region. The PCI bus
hierarchy is stored locally in IOMMUPciBus and IOMMUDevice
objects.

At the time of the enumeration, the bus number may not be
computed yet.

So operations that will need to retrieve the IOMMUdevice
and its IOMMU memory region from the bus number and devfn,
once the bus number is garanteed to be frozen, use an array
of IOMMUPciBus, lazily populated.

Signed-off-by: Eric Auger <eric.auger@redhat.com>

---

v12 -> v13:
- squashed v12 4, 5, 6 into this patch
- rename virtio_iommu_get_sid into virtio_iommu_get_bdf

v11 -> v12:
- check the device is protected by the iommu on attach
- on detach, check the domain id the device is attached to matches
  the one used in the detach command
- fix mapping ref counter and destroy the domain when no end-points
  are attached anymore
---
 hw/virtio/trace-events           |   6 +
 hw/virtio/virtio-iommu.c         | 315 ++++++++++++++++++++++++++++++-
 include/hw/virtio/virtio-iommu.h |   3 +
 3 files changed, 322 insertions(+), 2 deletions(-)

Comments

Peter Xu Feb. 3, 2020, 1:49 p.m. UTC | #1
On Sat, Jan 25, 2020 at 06:19:48PM +0100, Eric Auger wrote:
> This patch implements the endpoint attach/detach to/from
> a domain.
> 
> Domain and endpoint internal datatypes are introduced.
> Both are stored in RB trees. The domain owns a list of
> endpoints attached to it. Also helpers to get/put
> end points and domains are introduced.
> 
> As for the IOMMU memory regions, a callback is called on
> PCI bus enumeration that initializes for a given device
> on the bus hierarchy an IOMMU memory region. The PCI bus
> hierarchy is stored locally in IOMMUPciBus and IOMMUDevice
> objects.
> 
> At the time of the enumeration, the bus number may not be
> computed yet.
> 
> So operations that will need to retrieve the IOMMUdevice
> and its IOMMU memory region from the bus number and devfn,
> once the bus number is garanteed to be frozen, use an array
> of IOMMUPciBus, lazily populated.
> 
> Signed-off-by: Eric Auger <eric.auger@redhat.com>
> 
> ---
> 
> v12 -> v13:
> - squashed v12 4, 5, 6 into this patch
> - rename virtio_iommu_get_sid into virtio_iommu_get_bdf
> 
> v11 -> v12:
> - check the device is protected by the iommu on attach
> - on detach, check the domain id the device is attached to matches
>   the one used in the detach command
> - fix mapping ref counter and destroy the domain when no end-points
>   are attached anymore
> ---
>  hw/virtio/trace-events           |   6 +
>  hw/virtio/virtio-iommu.c         | 315 ++++++++++++++++++++++++++++++-
>  include/hw/virtio/virtio-iommu.h |   3 +
>  3 files changed, 322 insertions(+), 2 deletions(-)
> 
> diff --git a/hw/virtio/trace-events b/hw/virtio/trace-events
> index f7141aa2f6..15595f8cd7 100644
> --- a/hw/virtio/trace-events
> +++ b/hw/virtio/trace-events
> @@ -64,3 +64,9 @@ virtio_iommu_attach(uint32_t domain_id, uint32_t ep_id) "domain=%d endpoint=%d"
>  virtio_iommu_detach(uint32_t domain_id, uint32_t ep_id) "domain=%d endpoint=%d"
>  virtio_iommu_map(uint32_t domain_id, uint64_t virt_start, uint64_t virt_end, uint64_t phys_start, uint32_t flags) "domain=%d virt_start=0x%"PRIx64" virt_end=0x%"PRIx64 " phys_start=0x%"PRIx64" flags=%d"
>  virtio_iommu_unmap(uint32_t domain_id, uint64_t virt_start, uint64_t virt_end) "domain=%d virt_start=0x%"PRIx64" virt_end=0x%"PRIx64
> +virtio_iommu_translate(const char *name, uint32_t rid, uint64_t iova, int flag) "mr=%s rid=%d addr=0x%"PRIx64" flag=%d"
> +virtio_iommu_init_iommu_mr(char *iommu_mr) "init %s"
> +virtio_iommu_get_endpoint(uint32_t ep_id) "Alloc endpoint=%d"
> +virtio_iommu_put_endpoint(uint32_t ep_id) "Free endpoint=%d"
> +virtio_iommu_get_domain(uint32_t domain_id) "Alloc domain=%d"
> +virtio_iommu_put_domain(uint32_t domain_id) "Free domain=%d"
> diff --git a/hw/virtio/virtio-iommu.c b/hw/virtio/virtio-iommu.c
> index 9d1b997df7..e5cc94138b 100644
> --- a/hw/virtio/virtio-iommu.c
> +++ b/hw/virtio/virtio-iommu.c
> @@ -23,6 +23,8 @@
>  #include "hw/qdev-properties.h"
>  #include "hw/virtio/virtio.h"
>  #include "sysemu/kvm.h"
> +#include "qapi/error.h"
> +#include "qemu/error-report.h"
>  #include "trace.h"
>  
>  #include "standard-headers/linux/virtio_ids.h"
> @@ -30,19 +32,234 @@
>  #include "hw/virtio/virtio-bus.h"
>  #include "hw/virtio/virtio-access.h"
>  #include "hw/virtio/virtio-iommu.h"
> +#include "hw/pci/pci_bus.h"
> +#include "hw/pci/pci.h"
>  
>  /* Max size */
>  #define VIOMMU_DEFAULT_QUEUE_SIZE 256
>  
> +typedef struct VirtIOIOMMUDomain {
> +    uint32_t id;
> +    GTree *mappings;
> +    QLIST_HEAD(, VirtIOIOMMUEndpoint) endpoint_list;
> +} VirtIOIOMMUDomain;
> +
> +typedef struct VirtIOIOMMUEndpoint {
> +    uint32_t id;
> +    VirtIOIOMMUDomain *domain;
> +    QLIST_ENTRY(VirtIOIOMMUEndpoint) next;
> +} VirtIOIOMMUEndpoint;
> +
> +typedef struct VirtIOIOMMUInterval {
> +    uint64_t low;
> +    uint64_t high;
> +} VirtIOIOMMUInterval;
> +
> +static inline uint16_t virtio_iommu_get_bdf(IOMMUDevice *dev)
> +{
> +    return PCI_BUILD_BDF(pci_bus_num(dev->bus), dev->devfn);
> +}
> +
> +/**
> + * The bus number is used for lookup when SID based operations occur.
> + * In that case we lazily populate the IOMMUPciBus array from the bus hash
> + * table. At the time the IOMMUPciBus is created (iommu_find_add_as), the bus
> + * numbers may not be always initialized yet.
> + */
> +static IOMMUPciBus *iommu_find_iommu_pcibus(VirtIOIOMMU *s, uint8_t bus_num)
> +{
> +    IOMMUPciBus *iommu_pci_bus = s->iommu_pcibus_by_bus_num[bus_num];
> +
> +    if (!iommu_pci_bus) {
> +        GHashTableIter iter;
> +
> +        g_hash_table_iter_init(&iter, s->as_by_busptr);
> +        while (g_hash_table_iter_next(&iter, NULL, (void **)&iommu_pci_bus)) {
> +            if (pci_bus_num(iommu_pci_bus->bus) == bus_num) {
> +                s->iommu_pcibus_by_bus_num[bus_num] = iommu_pci_bus;
> +                return iommu_pci_bus;
> +            }
> +        }
> +        return NULL;
> +    }
> +    return iommu_pci_bus;
> +}
> +
> +static IOMMUMemoryRegion *virtio_iommu_mr(VirtIOIOMMU *s, uint32_t sid)
> +{
> +    uint8_t bus_n, devfn;
> +    IOMMUPciBus *iommu_pci_bus;
> +    IOMMUDevice *dev;
> +
> +    bus_n = PCI_BUS_NUM(sid);
> +    iommu_pci_bus = iommu_find_iommu_pcibus(s, bus_n);
> +    if (iommu_pci_bus) {
> +        devfn = sid & PCI_DEVFN_MAX;
> +        dev = iommu_pci_bus->pbdev[devfn];
> +        if (dev) {
> +            return &dev->iommu_mr;
> +        }
> +    }
> +    return NULL;
> +}
> +
> +static gint interval_cmp(gconstpointer a, gconstpointer b, gpointer user_data)
> +{
> +    VirtIOIOMMUInterval *inta = (VirtIOIOMMUInterval *)a;
> +    VirtIOIOMMUInterval *intb = (VirtIOIOMMUInterval *)b;
> +
> +    if (inta->high < intb->low) {
> +        return -1;
> +    } else if (intb->high < inta->low) {
> +        return 1;
> +    } else {
> +        return 0;
> +    }
> +}
> +
> +static void virtio_iommu_detach_endpoint_from_domain(VirtIOIOMMUEndpoint *ep)
> +{
> +    QLIST_REMOVE(ep, next);
> +    g_tree_unref(ep->domain->mappings);

Here domain->mapping is unreferenced for each endpoint, while at [1]
below you only reference the domain->mappings if it's the first
endpoint.  Is that problematic?

> +    ep->domain = NULL;
> +}
> +
> +static VirtIOIOMMUEndpoint *virtio_iommu_get_endpoint(VirtIOIOMMU *s,
> +                                                      uint32_t ep_id)
> +{
> +    VirtIOIOMMUEndpoint *ep;
> +
> +    ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(ep_id));
> +    if (ep) {
> +        return ep;
> +    }
> +    if (!virtio_iommu_mr(s, ep_id)) {
> +        return NULL;
> +    }
> +    ep = g_malloc0(sizeof(*ep));
> +    ep->id = ep_id;
> +    trace_virtio_iommu_get_endpoint(ep_id);
> +    g_tree_insert(s->endpoints, GUINT_TO_POINTER(ep_id), ep);
> +    return ep;
> +}
> +
> +static void virtio_iommu_put_endpoint(gpointer data)
> +{
> +    VirtIOIOMMUEndpoint *ep = (VirtIOIOMMUEndpoint *)data;
> +
> +    assert(!ep->domain);
> +
> +    trace_virtio_iommu_put_endpoint(ep->id);
> +    g_free(ep);
> +}
> +
> +static VirtIOIOMMUDomain *virtio_iommu_get_domain(VirtIOIOMMU *s,
> +                                                  uint32_t domain_id)
> +{
> +    VirtIOIOMMUDomain *domain;
> +
> +    domain = g_tree_lookup(s->domains, GUINT_TO_POINTER(domain_id));
> +    if (domain) {
> +        return domain;
> +    }
> +    domain = g_malloc0(sizeof(*domain));
> +    domain->id = domain_id;
> +    domain->mappings = g_tree_new_full((GCompareDataFunc)interval_cmp,
> +                                   NULL, (GDestroyNotify)g_free,
> +                                   (GDestroyNotify)g_free);
> +    g_tree_insert(s->domains, GUINT_TO_POINTER(domain_id), domain);
> +    QLIST_INIT(&domain->endpoint_list);
> +    trace_virtio_iommu_get_domain(domain_id);
> +    return domain;
> +}
> +
> +static void virtio_iommu_put_domain(gpointer data)
> +{
> +    VirtIOIOMMUDomain *domain = (VirtIOIOMMUDomain *)data;
> +    VirtIOIOMMUEndpoint *iter, *tmp;
> +
> +    QLIST_FOREACH_SAFE(iter, &domain->endpoint_list, next, tmp) {
> +        virtio_iommu_detach_endpoint_from_domain(iter);
> +    }

Do you need to destroy the domain->mappings here?  Is it leaked?

> +    trace_virtio_iommu_put_domain(domain->id);
> +    g_free(domain);
> +}
> +
> +static AddressSpace *virtio_iommu_find_add_as(PCIBus *bus, void *opaque,
> +                                              int devfn)
> +{
> +    VirtIOIOMMU *s = opaque;
> +    IOMMUPciBus *sbus = g_hash_table_lookup(s->as_by_busptr, bus);
> +    static uint32_t mr_index;
> +    IOMMUDevice *sdev;
> +
> +    if (!sbus) {
> +        sbus = g_malloc0(sizeof(IOMMUPciBus) +
> +                         sizeof(IOMMUDevice *) * PCI_DEVFN_MAX);
> +        sbus->bus = bus;
> +        g_hash_table_insert(s->as_by_busptr, bus, sbus);
> +    }
> +
> +    sdev = sbus->pbdev[devfn];
> +    if (!sdev) {
> +        char *name = g_strdup_printf("%s-%d-%d",
> +                                     TYPE_VIRTIO_IOMMU_MEMORY_REGION,
> +                                     mr_index++, devfn);
> +        sdev = sbus->pbdev[devfn] = g_malloc0(sizeof(IOMMUDevice));
> +
> +        sdev->viommu = s;
> +        sdev->bus = bus;
> +        sdev->devfn = devfn;
> +
> +        trace_virtio_iommu_init_iommu_mr(name);
> +
> +        memory_region_init_iommu(&sdev->iommu_mr, sizeof(sdev->iommu_mr),
> +                                 TYPE_VIRTIO_IOMMU_MEMORY_REGION,
> +                                 OBJECT(s), name,
> +                                 UINT64_MAX);
> +        address_space_init(&sdev->as,
> +                           MEMORY_REGION(&sdev->iommu_mr), TYPE_VIRTIO_IOMMU);
> +        g_free(name);
> +    }
> +    return &sdev->as;
> +}
> +
>  static int virtio_iommu_attach(VirtIOIOMMU *s,
>                                 struct virtio_iommu_req_attach *req)
>  {
>      uint32_t domain_id = le32_to_cpu(req->domain);
>      uint32_t ep_id = le32_to_cpu(req->endpoint);
> +    VirtIOIOMMUDomain *domain;
> +    VirtIOIOMMUEndpoint *ep;
>  
>      trace_virtio_iommu_attach(domain_id, ep_id);
>  
> -    return VIRTIO_IOMMU_S_UNSUPP;
> +    ep = virtio_iommu_get_endpoint(s, ep_id);
> +    if (!ep) {
> +        return VIRTIO_IOMMU_S_NOENT;
> +    }
> +
> +    if (ep->domain) {
> +        VirtIOIOMMUDomain *previous_domain = ep->domain;
> +        /*
> +         * the device is already attached to a domain,
> +         * detach it first
> +         */
> +        virtio_iommu_detach_endpoint_from_domain(ep);
> +        if (QLIST_EMPTY(&previous_domain->endpoint_list)) {
> +            g_tree_remove(s->domains, GUINT_TO_POINTER(previous_domain->id));
> +        }
> +    }
> +
> +    domain = virtio_iommu_get_domain(s, domain_id);
> +    if (!QLIST_EMPTY(&domain->endpoint_list)) {
> +        g_tree_ref(domain->mappings);

[1]

> +    }
> +    QLIST_INSERT_HEAD(&domain->endpoint_list, ep, next);
> +
> +    ep->domain = domain;
> +
> +    return VIRTIO_IOMMU_S_OK;
>  }
>  
>  static int virtio_iommu_detach(VirtIOIOMMU *s,
> @@ -50,10 +267,34 @@ static int virtio_iommu_detach(VirtIOIOMMU *s,
>  {
>      uint32_t domain_id = le32_to_cpu(req->domain);
>      uint32_t ep_id = le32_to_cpu(req->endpoint);
> +    VirtIOIOMMUDomain *domain;
> +    VirtIOIOMMUEndpoint *ep;
>  
>      trace_virtio_iommu_detach(domain_id, ep_id);
>  
> -    return VIRTIO_IOMMU_S_UNSUPP;
> +    ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(ep_id));
> +    if (!ep) {
> +        return VIRTIO_IOMMU_S_NOENT;
> +    }
> +
> +    domain = ep->domain;
> +
> +    if (!domain || domain->id != domain_id) {
> +        return VIRTIO_IOMMU_S_INVAL;
> +    }
> +
> +    virtio_iommu_detach_endpoint_from_domain(ep);
> +
> +    /*
> +     * when the last EP is detached, simply remove the domain for
> +     * the domain list and destroy it. Note its mappings were already
> +     * freed by the ref count mechanism. Next operation involving
> +     * the same domain id will re-create one domain object.
> +     */
> +    if (QLIST_EMPTY(&domain->endpoint_list)) {
> +        g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
> +    }
> +    return VIRTIO_IOMMU_S_OK;
>  }
>  
>  static int virtio_iommu_map(VirtIOIOMMU *s,
> @@ -172,6 +413,27 @@ out:
>      }
>  }
>  
> +static IOMMUTLBEntry virtio_iommu_translate(IOMMUMemoryRegion *mr, hwaddr addr,
> +                                            IOMMUAccessFlags flag,
> +                                            int iommu_idx)
> +{
> +    IOMMUDevice *sdev = container_of(mr, IOMMUDevice, iommu_mr);
> +    uint32_t sid;
> +
> +    IOMMUTLBEntry entry = {
> +        .target_as = &address_space_memory,
> +        .iova = addr,
> +        .translated_addr = addr,
> +        .addr_mask = ~(hwaddr)0,
> +        .perm = IOMMU_NONE,
> +    };
> +
> +    sid = virtio_iommu_get_bdf(sdev);
> +
> +    trace_virtio_iommu_translate(mr->parent_obj.name, sid, addr, flag);
> +    return entry;
> +}
> +
>  static void virtio_iommu_get_config(VirtIODevice *vdev, uint8_t *config_data)
>  {
>      VirtIOIOMMU *dev = VIRTIO_IOMMU(vdev);
> @@ -218,6 +480,13 @@ static const VMStateDescription vmstate_virtio_iommu_device = {
>      .unmigratable = 1,
>  };
>  
> +static gint int_cmp(gconstpointer a, gconstpointer b, gpointer user_data)
> +{
> +    guint ua = GPOINTER_TO_UINT(a);
> +    guint ub = GPOINTER_TO_UINT(b);
> +    return (ua > ub) - (ua < ub);
> +}
> +
>  static void virtio_iommu_device_realize(DeviceState *dev, Error **errp)
>  {
>      VirtIODevice *vdev = VIRTIO_DEVICE(dev);
> @@ -226,6 +495,8 @@ static void virtio_iommu_device_realize(DeviceState *dev, Error **errp)
>      virtio_init(vdev, "virtio-iommu", VIRTIO_ID_IOMMU,
>                  sizeof(struct virtio_iommu_config));
>  
> +    memset(s->iommu_pcibus_by_bus_num, 0, sizeof(s->iommu_pcibus_by_bus_num));
> +
>      s->req_vq = virtio_add_queue(vdev, VIOMMU_DEFAULT_QUEUE_SIZE,
>                               virtio_iommu_handle_command);
>      s->event_vq = virtio_add_queue(vdev, VIOMMU_DEFAULT_QUEUE_SIZE, NULL);
> @@ -244,18 +515,43 @@ static void virtio_iommu_device_realize(DeviceState *dev, Error **errp)
>      virtio_add_feature(&s->features, VIRTIO_IOMMU_F_MMIO);
>  
>      qemu_mutex_init(&s->mutex);
> +
> +    s->as_by_busptr = g_hash_table_new_full(NULL, NULL, NULL, g_free);
> +
> +    if (s->primary_bus) {
> +        pci_setup_iommu(s->primary_bus, virtio_iommu_find_add_as, s);
> +    } else {
> +        error_setg(errp, "VIRTIO-IOMMU is not attached to any PCI bus!");
> +    }
>  }
>  
>  static void virtio_iommu_device_unrealize(DeviceState *dev, Error **errp)
>  {
>      VirtIODevice *vdev = VIRTIO_DEVICE(dev);
> +    VirtIOIOMMU *s = VIRTIO_IOMMU(dev);
> +
> +    g_tree_destroy(s->domains);
> +    g_tree_destroy(s->endpoints);
>  
>      virtio_cleanup(vdev);
>  }
>  
>  static void virtio_iommu_device_reset(VirtIODevice *vdev)
>  {
> +    VirtIOIOMMU *s = VIRTIO_IOMMU(vdev);
> +
>      trace_virtio_iommu_device_reset();
> +
> +    if (s->domains) {
> +        g_tree_destroy(s->domains);
> +    }
> +    if (s->endpoints) {
> +        g_tree_destroy(s->endpoints);
> +    }

Is it a must to free domians first then the endpoints here?

I see that virtio_iommu_put_domain() will unlink the domains and
endpoints, then in virtio_iommu_put_endpoint() you assert that
ep->domain is NULL.  It's fine but I'm a bit curious on why not call
virtio_iommu_detach_endpoint_from_domain() too when put the endpoint
then iiuc we don't even need this ordering constraint.  Though in
virtio_iommu_detach_endpoint_from_domain() we should also need:

  if (!ep->domain)
    return;

Otherwise it looks good to me.  Thanks,

> +    s->domains = g_tree_new_full((GCompareDataFunc)int_cmp,
> +                                 NULL, NULL, virtio_iommu_put_domain);
> +    s->endpoints = g_tree_new_full((GCompareDataFunc)int_cmp,
> +                                   NULL, NULL, virtio_iommu_put_endpoint);
>  }
>  
>  static void virtio_iommu_set_status(VirtIODevice *vdev, uint8_t status)
> @@ -301,6 +597,14 @@ static void virtio_iommu_class_init(ObjectClass *klass, void *data)
>      vdc->vmsd = &vmstate_virtio_iommu_device;
>  }
>  
> +static void virtio_iommu_memory_region_class_init(ObjectClass *klass,
> +                                                  void *data)
> +{
> +    IOMMUMemoryRegionClass *imrc = IOMMU_MEMORY_REGION_CLASS(klass);
> +
> +    imrc->translate = virtio_iommu_translate;
> +}
> +
>  static const TypeInfo virtio_iommu_info = {
>      .name = TYPE_VIRTIO_IOMMU,
>      .parent = TYPE_VIRTIO_DEVICE,
> @@ -309,9 +613,16 @@ static const TypeInfo virtio_iommu_info = {
>      .class_init = virtio_iommu_class_init,
>  };
>  
> +static const TypeInfo virtio_iommu_memory_region_info = {
> +    .parent = TYPE_IOMMU_MEMORY_REGION,
> +    .name = TYPE_VIRTIO_IOMMU_MEMORY_REGION,
> +    .class_init = virtio_iommu_memory_region_class_init,
> +};
> +
>  static void virtio_register_types(void)
>  {
>      type_register_static(&virtio_iommu_info);
> +    type_register_static(&virtio_iommu_memory_region_info);
>  }
>  
>  type_init(virtio_register_types)
> diff --git a/include/hw/virtio/virtio-iommu.h b/include/hw/virtio/virtio-iommu.h
> index 041f2c9390..2a2c2ecf83 100644
> --- a/include/hw/virtio/virtio-iommu.h
> +++ b/include/hw/virtio/virtio-iommu.h
> @@ -28,6 +28,8 @@
>  #define VIRTIO_IOMMU(obj) \
>          OBJECT_CHECK(VirtIOIOMMU, (obj), TYPE_VIRTIO_IOMMU)
>  
> +#define TYPE_VIRTIO_IOMMU_MEMORY_REGION "virtio-iommu-memory-region"
> +
>  typedef struct IOMMUDevice {
>      void         *viommu;
>      PCIBus       *bus;
> @@ -48,6 +50,7 @@ typedef struct VirtIOIOMMU {
>      struct virtio_iommu_config config;
>      uint64_t features;
>      GHashTable *as_by_busptr;
> +    IOMMUPciBus *iommu_pcibus_by_bus_num[PCI_BUS_MAX];
>      PCIBus *primary_bus;
>      GTree *domains;
>      QemuMutex mutex;
> -- 
> 2.20.1
>
Eric Auger Feb. 3, 2020, 2:59 p.m. UTC | #2
Hi Peter,

On 2/3/20 2:49 PM, Peter Xu wrote:
> On Sat, Jan 25, 2020 at 06:19:48PM +0100, Eric Auger wrote:
>> This patch implements the endpoint attach/detach to/from
>> a domain.
>>
>> Domain and endpoint internal datatypes are introduced.
>> Both are stored in RB trees. The domain owns a list of
>> endpoints attached to it. Also helpers to get/put
>> end points and domains are introduced.
>>
>> As for the IOMMU memory regions, a callback is called on
>> PCI bus enumeration that initializes for a given device
>> on the bus hierarchy an IOMMU memory region. The PCI bus
>> hierarchy is stored locally in IOMMUPciBus and IOMMUDevice
>> objects.
>>
>> At the time of the enumeration, the bus number may not be
>> computed yet.
>>
>> So operations that will need to retrieve the IOMMUdevice
>> and its IOMMU memory region from the bus number and devfn,
>> once the bus number is garanteed to be frozen, use an array
>> of IOMMUPciBus, lazily populated.
>>
>> Signed-off-by: Eric Auger <eric.auger@redhat.com>
>>
>> ---
>>
>> v12 -> v13:
>> - squashed v12 4, 5, 6 into this patch
>> - rename virtio_iommu_get_sid into virtio_iommu_get_bdf
>>
>> v11 -> v12:
>> - check the device is protected by the iommu on attach
>> - on detach, check the domain id the device is attached to matches
>>   the one used in the detach command
>> - fix mapping ref counter and destroy the domain when no end-points
>>   are attached anymore
>> ---
>>  hw/virtio/trace-events           |   6 +
>>  hw/virtio/virtio-iommu.c         | 315 ++++++++++++++++++++++++++++++-
>>  include/hw/virtio/virtio-iommu.h |   3 +
>>  3 files changed, 322 insertions(+), 2 deletions(-)
>>
>> diff --git a/hw/virtio/trace-events b/hw/virtio/trace-events
>> index f7141aa2f6..15595f8cd7 100644
>> --- a/hw/virtio/trace-events
>> +++ b/hw/virtio/trace-events
>> @@ -64,3 +64,9 @@ virtio_iommu_attach(uint32_t domain_id, uint32_t ep_id) "domain=%d endpoint=%d"
>>  virtio_iommu_detach(uint32_t domain_id, uint32_t ep_id) "domain=%d endpoint=%d"
>>  virtio_iommu_map(uint32_t domain_id, uint64_t virt_start, uint64_t virt_end, uint64_t phys_start, uint32_t flags) "domain=%d virt_start=0x%"PRIx64" virt_end=0x%"PRIx64 " phys_start=0x%"PRIx64" flags=%d"
>>  virtio_iommu_unmap(uint32_t domain_id, uint64_t virt_start, uint64_t virt_end) "domain=%d virt_start=0x%"PRIx64" virt_end=0x%"PRIx64
>> +virtio_iommu_translate(const char *name, uint32_t rid, uint64_t iova, int flag) "mr=%s rid=%d addr=0x%"PRIx64" flag=%d"
>> +virtio_iommu_init_iommu_mr(char *iommu_mr) "init %s"
>> +virtio_iommu_get_endpoint(uint32_t ep_id) "Alloc endpoint=%d"
>> +virtio_iommu_put_endpoint(uint32_t ep_id) "Free endpoint=%d"
>> +virtio_iommu_get_domain(uint32_t domain_id) "Alloc domain=%d"
>> +virtio_iommu_put_domain(uint32_t domain_id) "Free domain=%d"
>> diff --git a/hw/virtio/virtio-iommu.c b/hw/virtio/virtio-iommu.c
>> index 9d1b997df7..e5cc94138b 100644
>> --- a/hw/virtio/virtio-iommu.c
>> +++ b/hw/virtio/virtio-iommu.c
>> @@ -23,6 +23,8 @@
>>  #include "hw/qdev-properties.h"
>>  #include "hw/virtio/virtio.h"
>>  #include "sysemu/kvm.h"
>> +#include "qapi/error.h"
>> +#include "qemu/error-report.h"
>>  #include "trace.h"
>>  
>>  #include "standard-headers/linux/virtio_ids.h"
>> @@ -30,19 +32,234 @@
>>  #include "hw/virtio/virtio-bus.h"
>>  #include "hw/virtio/virtio-access.h"
>>  #include "hw/virtio/virtio-iommu.h"
>> +#include "hw/pci/pci_bus.h"
>> +#include "hw/pci/pci.h"
>>  
>>  /* Max size */
>>  #define VIOMMU_DEFAULT_QUEUE_SIZE 256
>>  
>> +typedef struct VirtIOIOMMUDomain {
>> +    uint32_t id;
>> +    GTree *mappings;
>> +    QLIST_HEAD(, VirtIOIOMMUEndpoint) endpoint_list;
>> +} VirtIOIOMMUDomain;
>> +
>> +typedef struct VirtIOIOMMUEndpoint {
>> +    uint32_t id;
>> +    VirtIOIOMMUDomain *domain;
>> +    QLIST_ENTRY(VirtIOIOMMUEndpoint) next;
>> +} VirtIOIOMMUEndpoint;
>> +
>> +typedef struct VirtIOIOMMUInterval {
>> +    uint64_t low;
>> +    uint64_t high;
>> +} VirtIOIOMMUInterval;
>> +
>> +static inline uint16_t virtio_iommu_get_bdf(IOMMUDevice *dev)
>> +{
>> +    return PCI_BUILD_BDF(pci_bus_num(dev->bus), dev->devfn);
>> +}
>> +
>> +/**
>> + * The bus number is used for lookup when SID based operations occur.
>> + * In that case we lazily populate the IOMMUPciBus array from the bus hash
>> + * table. At the time the IOMMUPciBus is created (iommu_find_add_as), the bus
>> + * numbers may not be always initialized yet.
>> + */
>> +static IOMMUPciBus *iommu_find_iommu_pcibus(VirtIOIOMMU *s, uint8_t bus_num)
>> +{
>> +    IOMMUPciBus *iommu_pci_bus = s->iommu_pcibus_by_bus_num[bus_num];
>> +
>> +    if (!iommu_pci_bus) {
>> +        GHashTableIter iter;
>> +
>> +        g_hash_table_iter_init(&iter, s->as_by_busptr);
>> +        while (g_hash_table_iter_next(&iter, NULL, (void **)&iommu_pci_bus)) {
>> +            if (pci_bus_num(iommu_pci_bus->bus) == bus_num) {
>> +                s->iommu_pcibus_by_bus_num[bus_num] = iommu_pci_bus;
>> +                return iommu_pci_bus;
>> +            }
>> +        }
>> +        return NULL;
>> +    }
>> +    return iommu_pci_bus;
>> +}
>> +
>> +static IOMMUMemoryRegion *virtio_iommu_mr(VirtIOIOMMU *s, uint32_t sid)
>> +{
>> +    uint8_t bus_n, devfn;
>> +    IOMMUPciBus *iommu_pci_bus;
>> +    IOMMUDevice *dev;
>> +
>> +    bus_n = PCI_BUS_NUM(sid);
>> +    iommu_pci_bus = iommu_find_iommu_pcibus(s, bus_n);
>> +    if (iommu_pci_bus) {
>> +        devfn = sid & PCI_DEVFN_MAX;
>> +        dev = iommu_pci_bus->pbdev[devfn];
>> +        if (dev) {
>> +            return &dev->iommu_mr;
>> +        }
>> +    }
>> +    return NULL;
>> +}
>> +
>> +static gint interval_cmp(gconstpointer a, gconstpointer b, gpointer user_data)
>> +{
>> +    VirtIOIOMMUInterval *inta = (VirtIOIOMMUInterval *)a;
>> +    VirtIOIOMMUInterval *intb = (VirtIOIOMMUInterval *)b;
>> +
>> +    if (inta->high < intb->low) {
>> +        return -1;
>> +    } else if (intb->high < inta->low) {
>> +        return 1;
>> +    } else {
>> +        return 0;
>> +    }
>> +}
>> +
>> +static void virtio_iommu_detach_endpoint_from_domain(VirtIOIOMMUEndpoint *ep)
>> +{
>> +    QLIST_REMOVE(ep, next);
>> +    g_tree_unref(ep->domain->mappings);
> 
> Here domain->mapping is unreferenced for each endpoint, while at [1]
> below you only reference the domain->mappings if it's the first
> endpoint.  Is that problematic?
in [1] I take a ref to the domain->mappings if it is *not* the 1st
endpoint. This aims at deleting the mappings gtree when the last EP is
detached from the domain.

This fixes the issue reported by Jean in:
https://patchwork.kernel.org/patch/11258267/#23046313
> 
>> +    ep->domain = NULL;
>> +}
>> +
>> +static VirtIOIOMMUEndpoint *virtio_iommu_get_endpoint(VirtIOIOMMU *s,
>> +                                                      uint32_t ep_id)
>> +{
>> +    VirtIOIOMMUEndpoint *ep;
>> +
>> +    ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(ep_id));
>> +    if (ep) {
>> +        return ep;
>> +    }
>> +    if (!virtio_iommu_mr(s, ep_id)) {
>> +        return NULL;
>> +    }
>> +    ep = g_malloc0(sizeof(*ep));
>> +    ep->id = ep_id;
>> +    trace_virtio_iommu_get_endpoint(ep_id);
>> +    g_tree_insert(s->endpoints, GUINT_TO_POINTER(ep_id), ep);
>> +    return ep;
>> +}
>> +
>> +static void virtio_iommu_put_endpoint(gpointer data)
>> +{
>> +    VirtIOIOMMUEndpoint *ep = (VirtIOIOMMUEndpoint *)data;
>> +
>> +    assert(!ep->domain);
>> +
>> +    trace_virtio_iommu_put_endpoint(ep->id);
>> +    g_free(ep);
>> +}
>> +
>> +static VirtIOIOMMUDomain *virtio_iommu_get_domain(VirtIOIOMMU *s,
>> +                                                  uint32_t domain_id)
>> +{
>> +    VirtIOIOMMUDomain *domain;
>> +
>> +    domain = g_tree_lookup(s->domains, GUINT_TO_POINTER(domain_id));
>> +    if (domain) {
>> +        return domain;
>> +    }
>> +    domain = g_malloc0(sizeof(*domain));
>> +    domain->id = domain_id;
>> +    domain->mappings = g_tree_new_full((GCompareDataFunc)interval_cmp,
>> +                                   NULL, (GDestroyNotify)g_free,
>> +                                   (GDestroyNotify)g_free);
>> +    g_tree_insert(s->domains, GUINT_TO_POINTER(domain_id), domain);
>> +    QLIST_INIT(&domain->endpoint_list);
>> +    trace_virtio_iommu_get_domain(domain_id);
>> +    return domain;
>> +}
>> +
>> +static void virtio_iommu_put_domain(gpointer data)
>> +{
>> +    VirtIOIOMMUDomain *domain = (VirtIOIOMMUDomain *)data;
>> +    VirtIOIOMMUEndpoint *iter, *tmp;
>> +
>> +    QLIST_FOREACH_SAFE(iter, &domain->endpoint_list, next, tmp) {
>> +        virtio_iommu_detach_endpoint_from_domain(iter);
>> +    }
> 
> Do you need to destroy the domain->mappings here?  Is it leaked?
AFIU as we detach all EPs in the loop above, the whole "mappings" gtree
is destroyed so there is no leak.
> 
>> +    trace_virtio_iommu_put_domain(domain->id);
>> +    g_free(domain);
>> +}
>> +
>> +static AddressSpace *virtio_iommu_find_add_as(PCIBus *bus, void *opaque,
>> +                                              int devfn)
>> +{
>> +    VirtIOIOMMU *s = opaque;
>> +    IOMMUPciBus *sbus = g_hash_table_lookup(s->as_by_busptr, bus);
>> +    static uint32_t mr_index;
>> +    IOMMUDevice *sdev;
>> +
>> +    if (!sbus) {
>> +        sbus = g_malloc0(sizeof(IOMMUPciBus) +
>> +                         sizeof(IOMMUDevice *) * PCI_DEVFN_MAX);
>> +        sbus->bus = bus;
>> +        g_hash_table_insert(s->as_by_busptr, bus, sbus);
>> +    }
>> +
>> +    sdev = sbus->pbdev[devfn];
>> +    if (!sdev) {
>> +        char *name = g_strdup_printf("%s-%d-%d",
>> +                                     TYPE_VIRTIO_IOMMU_MEMORY_REGION,
>> +                                     mr_index++, devfn);
>> +        sdev = sbus->pbdev[devfn] = g_malloc0(sizeof(IOMMUDevice));
>> +
>> +        sdev->viommu = s;
>> +        sdev->bus = bus;
>> +        sdev->devfn = devfn;
>> +
>> +        trace_virtio_iommu_init_iommu_mr(name);
>> +
>> +        memory_region_init_iommu(&sdev->iommu_mr, sizeof(sdev->iommu_mr),
>> +                                 TYPE_VIRTIO_IOMMU_MEMORY_REGION,
>> +                                 OBJECT(s), name,
>> +                                 UINT64_MAX);
>> +        address_space_init(&sdev->as,
>> +                           MEMORY_REGION(&sdev->iommu_mr), TYPE_VIRTIO_IOMMU);
>> +        g_free(name);
>> +    }
>> +    return &sdev->as;
>> +}
>> +
>>  static int virtio_iommu_attach(VirtIOIOMMU *s,
>>                                 struct virtio_iommu_req_attach *req)
>>  {
>>      uint32_t domain_id = le32_to_cpu(req->domain);
>>      uint32_t ep_id = le32_to_cpu(req->endpoint);
>> +    VirtIOIOMMUDomain *domain;
>> +    VirtIOIOMMUEndpoint *ep;
>>  
>>      trace_virtio_iommu_attach(domain_id, ep_id);
>>  
>> -    return VIRTIO_IOMMU_S_UNSUPP;
>> +    ep = virtio_iommu_get_endpoint(s, ep_id);
>> +    if (!ep) {
>> +        return VIRTIO_IOMMU_S_NOENT;
>> +    }
>> +
>> +    if (ep->domain) {
>> +        VirtIOIOMMUDomain *previous_domain = ep->domain;
>> +        /*
>> +         * the device is already attached to a domain,
>> +         * detach it first
>> +         */
>> +        virtio_iommu_detach_endpoint_from_domain(ep);
>> +        if (QLIST_EMPTY(&previous_domain->endpoint_list)) {
>> +            g_tree_remove(s->domains, GUINT_TO_POINTER(previous_domain->id));
>> +        }
>> +    }
>> +
>> +    domain = virtio_iommu_get_domain(s, domain_id);
>> +    if (!QLIST_EMPTY(&domain->endpoint_list)) {
>> +        g_tree_ref(domain->mappings);
> 
> [1]
!QLIST_EMPTY
> 
>> +    }
>> +    QLIST_INSERT_HEAD(&domain->endpoint_list, ep, next);
>> +
>> +    ep->domain = domain;
>> +
>> +    return VIRTIO_IOMMU_S_OK;
>>  }
>>  
>>  static int virtio_iommu_detach(VirtIOIOMMU *s,
>> @@ -50,10 +267,34 @@ static int virtio_iommu_detach(VirtIOIOMMU *s,
>>  {
>>      uint32_t domain_id = le32_to_cpu(req->domain);
>>      uint32_t ep_id = le32_to_cpu(req->endpoint);
>> +    VirtIOIOMMUDomain *domain;
>> +    VirtIOIOMMUEndpoint *ep;
>>  
>>      trace_virtio_iommu_detach(domain_id, ep_id);
>>  
>> -    return VIRTIO_IOMMU_S_UNSUPP;
>> +    ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(ep_id));
>> +    if (!ep) {
>> +        return VIRTIO_IOMMU_S_NOENT;
>> +    }
>> +
>> +    domain = ep->domain;
>> +
>> +    if (!domain || domain->id != domain_id) {
>> +        return VIRTIO_IOMMU_S_INVAL;
>> +    }
>> +
>> +    virtio_iommu_detach_endpoint_from_domain(ep);
>> +
>> +    /*
>> +     * when the last EP is detached, simply remove the domain for
>> +     * the domain list and destroy it. Note its mappings were already
>> +     * freed by the ref count mechanism. Next operation involving
>> +     * the same domain id will re-create one domain object.
>> +     */
>> +    if (QLIST_EMPTY(&domain->endpoint_list)) {
>> +        g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
>> +    }
>> +    return VIRTIO_IOMMU_S_OK;
>>  }
>>  
>>  static int virtio_iommu_map(VirtIOIOMMU *s,
>> @@ -172,6 +413,27 @@ out:
>>      }
>>  }
>>  
>> +static IOMMUTLBEntry virtio_iommu_translate(IOMMUMemoryRegion *mr, hwaddr addr,
>> +                                            IOMMUAccessFlags flag,
>> +                                            int iommu_idx)
>> +{
>> +    IOMMUDevice *sdev = container_of(mr, IOMMUDevice, iommu_mr);
>> +    uint32_t sid;
>> +
>> +    IOMMUTLBEntry entry = {
>> +        .target_as = &address_space_memory,
>> +        .iova = addr,
>> +        .translated_addr = addr,
>> +        .addr_mask = ~(hwaddr)0,
>> +        .perm = IOMMU_NONE,
>> +    };
>> +
>> +    sid = virtio_iommu_get_bdf(sdev);
>> +
>> +    trace_virtio_iommu_translate(mr->parent_obj.name, sid, addr, flag);
>> +    return entry;
>> +}
>> +
>>  static void virtio_iommu_get_config(VirtIODevice *vdev, uint8_t *config_data)
>>  {
>>      VirtIOIOMMU *dev = VIRTIO_IOMMU(vdev);
>> @@ -218,6 +480,13 @@ static const VMStateDescription vmstate_virtio_iommu_device = {
>>      .unmigratable = 1,
>>  };
>>  
>> +static gint int_cmp(gconstpointer a, gconstpointer b, gpointer user_data)
>> +{
>> +    guint ua = GPOINTER_TO_UINT(a);
>> +    guint ub = GPOINTER_TO_UINT(b);
>> +    return (ua > ub) - (ua < ub);
>> +}
>> +
>>  static void virtio_iommu_device_realize(DeviceState *dev, Error **errp)
>>  {
>>      VirtIODevice *vdev = VIRTIO_DEVICE(dev);
>> @@ -226,6 +495,8 @@ static void virtio_iommu_device_realize(DeviceState *dev, Error **errp)
>>      virtio_init(vdev, "virtio-iommu", VIRTIO_ID_IOMMU,
>>                  sizeof(struct virtio_iommu_config));
>>  
>> +    memset(s->iommu_pcibus_by_bus_num, 0, sizeof(s->iommu_pcibus_by_bus_num));
>> +
>>      s->req_vq = virtio_add_queue(vdev, VIOMMU_DEFAULT_QUEUE_SIZE,
>>                               virtio_iommu_handle_command);
>>      s->event_vq = virtio_add_queue(vdev, VIOMMU_DEFAULT_QUEUE_SIZE, NULL);
>> @@ -244,18 +515,43 @@ static void virtio_iommu_device_realize(DeviceState *dev, Error **errp)
>>      virtio_add_feature(&s->features, VIRTIO_IOMMU_F_MMIO);
>>  
>>      qemu_mutex_init(&s->mutex);
>> +
>> +    s->as_by_busptr = g_hash_table_new_full(NULL, NULL, NULL, g_free);
>> +
>> +    if (s->primary_bus) {
>> +        pci_setup_iommu(s->primary_bus, virtio_iommu_find_add_as, s);
>> +    } else {
>> +        error_setg(errp, "VIRTIO-IOMMU is not attached to any PCI bus!");
>> +    }
>>  }
>>  
>>  static void virtio_iommu_device_unrealize(DeviceState *dev, Error **errp)
>>  {
>>      VirtIODevice *vdev = VIRTIO_DEVICE(dev);
>> +    VirtIOIOMMU *s = VIRTIO_IOMMU(dev);
>> +
>> +    g_tree_destroy(s->domains);
>> +    g_tree_destroy(s->endpoints);
>>  
>>      virtio_cleanup(vdev);
>>  }
>>  
>>  static void virtio_iommu_device_reset(VirtIODevice *vdev)
>>  {
>> +    VirtIOIOMMU *s = VIRTIO_IOMMU(vdev);
>> +
>>      trace_virtio_iommu_device_reset();
>> +
>> +    if (s->domains) {
>> +        g_tree_destroy(s->domains);
>> +    }
>> +    if (s->endpoints) {
>> +        g_tree_destroy(s->endpoints);
>> +    }
> 
> Is it a must to free domians first then the endpoints here?
> 
> I see that virtio_iommu_put_domain() will unlink the domains and
> endpoints, then in virtio_iommu_put_endpoint() you assert that
> ep->domain is NULL.  It's fine but I'm a bit curious on why not call
> virtio_iommu_detach_endpoint_from_domain() too when put the endpoint
> then iiuc we don't even need this ordering constraint.  Though in
> virtio_iommu_detach_endpoint_from_domain() we should also need:

Yes today an EP is meant to be deleted if it is detached from any domain.

I may add virtio_iommu_detach_endpoint_from_domain in put_domain though.

> 
>   if (!ep->domain)
>     return;
> 
> Otherwise it looks good to me.  Thanks,
Thanks

Eric
> 
>> +    s->domains = g_tree_new_full((GCompareDataFunc)int_cmp,
>> +                                 NULL, NULL, virtio_iommu_put_domain);
>> +    s->endpoints = g_tree_new_full((GCompareDataFunc)int_cmp,
>> +                                   NULL, NULL, virtio_iommu_put_endpoint);
>>  }
>>  
>>  static void virtio_iommu_set_status(VirtIODevice *vdev, uint8_t status)
>> @@ -301,6 +597,14 @@ static void virtio_iommu_class_init(ObjectClass *klass, void *data)
>>      vdc->vmsd = &vmstate_virtio_iommu_device;
>>  }
>>  
>> +static void virtio_iommu_memory_region_class_init(ObjectClass *klass,
>> +                                                  void *data)
>> +{
>> +    IOMMUMemoryRegionClass *imrc = IOMMU_MEMORY_REGION_CLASS(klass);
>> +
>> +    imrc->translate = virtio_iommu_translate;
>> +}
>> +
>>  static const TypeInfo virtio_iommu_info = {
>>      .name = TYPE_VIRTIO_IOMMU,
>>      .parent = TYPE_VIRTIO_DEVICE,
>> @@ -309,9 +613,16 @@ static const TypeInfo virtio_iommu_info = {
>>      .class_init = virtio_iommu_class_init,
>>  };
>>  
>> +static const TypeInfo virtio_iommu_memory_region_info = {
>> +    .parent = TYPE_IOMMU_MEMORY_REGION,
>> +    .name = TYPE_VIRTIO_IOMMU_MEMORY_REGION,
>> +    .class_init = virtio_iommu_memory_region_class_init,
>> +};
>> +
>>  static void virtio_register_types(void)
>>  {
>>      type_register_static(&virtio_iommu_info);
>> +    type_register_static(&virtio_iommu_memory_region_info);
>>  }
>>  
>>  type_init(virtio_register_types)
>> diff --git a/include/hw/virtio/virtio-iommu.h b/include/hw/virtio/virtio-iommu.h
>> index 041f2c9390..2a2c2ecf83 100644
>> --- a/include/hw/virtio/virtio-iommu.h
>> +++ b/include/hw/virtio/virtio-iommu.h
>> @@ -28,6 +28,8 @@
>>  #define VIRTIO_IOMMU(obj) \
>>          OBJECT_CHECK(VirtIOIOMMU, (obj), TYPE_VIRTIO_IOMMU)
>>  
>> +#define TYPE_VIRTIO_IOMMU_MEMORY_REGION "virtio-iommu-memory-region"
>> +
>>  typedef struct IOMMUDevice {
>>      void         *viommu;
>>      PCIBus       *bus;
>> @@ -48,6 +50,7 @@ typedef struct VirtIOIOMMU {
>>      struct virtio_iommu_config config;
>>      uint64_t features;
>>      GHashTable *as_by_busptr;
>> +    IOMMUPciBus *iommu_pcibus_by_bus_num[PCI_BUS_MAX];
>>      PCIBus *primary_bus;
>>      GTree *domains;
>>      QemuMutex mutex;
>> -- 
>> 2.20.1
>>
>
Peter Xu Feb. 3, 2020, 3:19 p.m. UTC | #3
On Mon, Feb 03, 2020 at 03:59:00PM +0100, Auger Eric wrote:

[...]

> >> +static void virtio_iommu_detach_endpoint_from_domain(VirtIOIOMMUEndpoint *ep)
> >> +{
> >> +    QLIST_REMOVE(ep, next);
> >> +    g_tree_unref(ep->domain->mappings);
> > 
> > Here domain->mapping is unreferenced for each endpoint, while at [1]
> > below you only reference the domain->mappings if it's the first
> > endpoint.  Is that problematic?
> in [1] I take a ref to the domain->mappings if it is *not* the 1st
> endpoint. This aims at deleting the mappings gtree when the last EP is
> detached from the domain.
> 
> This fixes the issue reported by Jean in:
> https://patchwork.kernel.org/patch/11258267/#23046313

Ah OK. :)

However this is tricky.  How about do explicit g_tree_destroy() in
virtio_iommu_detach() when it's the last endpoint?  I see that you
have:

    /*
     * when the last EP is detached, simply remove the domain for
     * the domain list and destroy it. Note its mappings were already
     * freed by the ref count mechanism. Next operation involving
     * the same domain id will re-create one domain object.
     */
    if (QLIST_EMPTY(&domain->endpoint_list)) {
        g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
    }

Then it becomes:

    if (QLIST_EMPTY(&domain->endpoint_list)) {
        g_tree_destroy(domain->mappings);
        g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
    }

And also remove the trick in attach() so you take the domain ref
unconditionally.  Would that work?

Thanks,
Eric Auger Feb. 3, 2020, 5:46 p.m. UTC | #4
Hi Peter,

On 2/3/20 4:19 PM, Peter Xu wrote:
> On Mon, Feb 03, 2020 at 03:59:00PM +0100, Auger Eric wrote:
> 
> [...]
> 
>>>> +static void virtio_iommu_detach_endpoint_from_domain(VirtIOIOMMUEndpoint *ep)
>>>> +{
>>>> +    QLIST_REMOVE(ep, next);
>>>> +    g_tree_unref(ep->domain->mappings);
>>>
>>> Here domain->mapping is unreferenced for each endpoint, while at [1]
>>> below you only reference the domain->mappings if it's the first
>>> endpoint.  Is that problematic?
>> in [1] I take a ref to the domain->mappings if it is *not* the 1st
>> endpoint. This aims at deleting the mappings gtree when the last EP is
>> detached from the domain.
>>
>> This fixes the issue reported by Jean in:
>> https://patchwork.kernel.org/patch/11258267/#23046313
> 
> Ah OK. :)
> 
> However this is tricky.  How about do explicit g_tree_destroy() in
> virtio_iommu_detach() when it's the last endpoint?  I see that you
> have:
> 
>     /*
>      * when the last EP is detached, simply remove the domain for
>      * the domain list and destroy it. Note its mappings were already
>      * freed by the ref count mechanism. Next operation involving
>      * the same domain id will re-create one domain object.
>      */
>     if (QLIST_EMPTY(&domain->endpoint_list)) {
>         g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
>     }
> 
> Then it becomes:
> 
>     if (QLIST_EMPTY(&domain->endpoint_list)) {
>         g_tree_destroy(domain->mappings);
>         g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
>     }
> 
> And also remove the trick in attach() so you take the domain ref
> unconditionally.  Would that work?
Yes I think so. On the other hand this ref counting mechanism is also
made for that purpose of destroying objects without being forced to
explicitly call the destroy function.

Thanks

Eric
> 
> Thanks,
>
Peter Xu Feb. 3, 2020, 6:19 p.m. UTC | #5
On Mon, Feb 03, 2020 at 06:46:36PM +0100, Auger Eric wrote:
> Hi Peter,
> 
> On 2/3/20 4:19 PM, Peter Xu wrote:
> > On Mon, Feb 03, 2020 at 03:59:00PM +0100, Auger Eric wrote:
> > 
> > [...]
> > 
> >>>> +static void virtio_iommu_detach_endpoint_from_domain(VirtIOIOMMUEndpoint *ep)
> >>>> +{
> >>>> +    QLIST_REMOVE(ep, next);
> >>>> +    g_tree_unref(ep->domain->mappings);
> >>>
> >>> Here domain->mapping is unreferenced for each endpoint, while at [1]
> >>> below you only reference the domain->mappings if it's the first
> >>> endpoint.  Is that problematic?
> >> in [1] I take a ref to the domain->mappings if it is *not* the 1st
> >> endpoint. This aims at deleting the mappings gtree when the last EP is
> >> detached from the domain.
> >>
> >> This fixes the issue reported by Jean in:
> >> https://patchwork.kernel.org/patch/11258267/#23046313
> > 
> > Ah OK. :)
> > 
> > However this is tricky.  How about do explicit g_tree_destroy() in
> > virtio_iommu_detach() when it's the last endpoint?  I see that you
> > have:
> > 
> >     /*
> >      * when the last EP is detached, simply remove the domain for
> >      * the domain list and destroy it. Note its mappings were already
> >      * freed by the ref count mechanism. Next operation involving
> >      * the same domain id will re-create one domain object.
> >      */
> >     if (QLIST_EMPTY(&domain->endpoint_list)) {
> >         g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
> >     }
> > 
> > Then it becomes:
> > 
> >     if (QLIST_EMPTY(&domain->endpoint_list)) {
> >         g_tree_destroy(domain->mappings);
> >         g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
> >     }
> > 
> > And also remove the trick in attach() so you take the domain ref
> > unconditionally.  Would that work?
> Yes I think so. On the other hand this ref counting mechanism is also
> made for that purpose of destroying objects without being forced to
> explicitly call the destroy function.

IMHO that's two different things.  g_tree_destroy() should be the same
as g_tree_unref() here when the tree is empty.  It's really a matter
of easy reading of code:

void
g_tree_destroy (GTree *tree)
{
  g_return_if_fail (tree != NULL);

  g_tree_remove_all (tree);
  g_tree_unref (tree);
}

What we really changed here is to allow the ref/unref to be clearly
paired, i.e., for each EP it'll ref once and unref once.  The prvious
solution has the trick in that the 1st EP don't ref, the latter EPs
ref, and when the domain quits it doesn't unref to match the first
ref.  It's error prone to me.  Then, if we can do it in the paired way
easily, I don't see why not...

Thanks,
Eric Auger Feb. 4, 2020, 12:26 p.m. UTC | #6
Hi Peter,

On 2/3/20 7:19 PM, Peter Xu wrote:
> On Mon, Feb 03, 2020 at 06:46:36PM +0100, Auger Eric wrote:
>> Hi Peter,
>>
>> On 2/3/20 4:19 PM, Peter Xu wrote:
>>> On Mon, Feb 03, 2020 at 03:59:00PM +0100, Auger Eric wrote:
>>>
>>> [...]
>>>
>>>>>> +static void virtio_iommu_detach_endpoint_from_domain(VirtIOIOMMUEndpoint *ep)
>>>>>> +{
>>>>>> +    QLIST_REMOVE(ep, next);
>>>>>> +    g_tree_unref(ep->domain->mappings);
>>>>>
>>>>> Here domain->mapping is unreferenced for each endpoint, while at [1]
>>>>> below you only reference the domain->mappings if it's the first
>>>>> endpoint.  Is that problematic?
>>>> in [1] I take a ref to the domain->mappings if it is *not* the 1st
>>>> endpoint. This aims at deleting the mappings gtree when the last EP is
>>>> detached from the domain.
>>>>
>>>> This fixes the issue reported by Jean in:
>>>> https://patchwork.kernel.org/patch/11258267/#23046313
>>>
>>> Ah OK. :)
>>>
>>> However this is tricky.  How about do explicit g_tree_destroy() in
>>> virtio_iommu_detach() when it's the last endpoint?  I see that you
>>> have:
>>>
>>>     /*
>>>      * when the last EP is detached, simply remove the domain for
>>>      * the domain list and destroy it. Note its mappings were already
>>>      * freed by the ref count mechanism. Next operation involving
>>>      * the same domain id will re-create one domain object.
>>>      */
>>>     if (QLIST_EMPTY(&domain->endpoint_list)) {
>>>         g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
>>>     }
>>>
>>> Then it becomes:
>>>
>>>     if (QLIST_EMPTY(&domain->endpoint_list)) {
>>>         g_tree_destroy(domain->mappings);
>>>         g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
>>>     }
>>>
>>> And also remove the trick in attach() so you take the domain ref
>>> unconditionally.  Would that work?
>> Yes I think so. On the other hand this ref counting mechanism is also
>> made for that purpose of destroying objects without being forced to
>> explicitly call the destroy function.
> 
> IMHO that's two different things.  g_tree_destroy() should be the same
> as g_tree_unref() here when the tree is empty.  It's really a matter
> of easy reading of code:
> 
> void
> g_tree_destroy (GTree *tree)
> {
>   g_return_if_fail (tree != NULL);
> 
>   g_tree_remove_all (tree);
>   g_tree_unref (tree);
> }
> 
> What we really changed here is to allow the ref/unref to be clearly
> paired, i.e., for each EP it'll ref once and unref once.  The prvious
> solution has the trick in that the 1st EP don't ref, the latter EPs
> ref, and when the domain quits it doesn't unref to match the first
> ref.  It's error prone to me.  Then, if we can do it in the paired way
> easily, I don't see why not...

OK. I will respin according to your suggestion.

Thanks

Eric
> 
> Thanks,
>
diff mbox series

Patch

diff --git a/hw/virtio/trace-events b/hw/virtio/trace-events
index f7141aa2f6..15595f8cd7 100644
--- a/hw/virtio/trace-events
+++ b/hw/virtio/trace-events
@@ -64,3 +64,9 @@  virtio_iommu_attach(uint32_t domain_id, uint32_t ep_id) "domain=%d endpoint=%d"
 virtio_iommu_detach(uint32_t domain_id, uint32_t ep_id) "domain=%d endpoint=%d"
 virtio_iommu_map(uint32_t domain_id, uint64_t virt_start, uint64_t virt_end, uint64_t phys_start, uint32_t flags) "domain=%d virt_start=0x%"PRIx64" virt_end=0x%"PRIx64 " phys_start=0x%"PRIx64" flags=%d"
 virtio_iommu_unmap(uint32_t domain_id, uint64_t virt_start, uint64_t virt_end) "domain=%d virt_start=0x%"PRIx64" virt_end=0x%"PRIx64
+virtio_iommu_translate(const char *name, uint32_t rid, uint64_t iova, int flag) "mr=%s rid=%d addr=0x%"PRIx64" flag=%d"
+virtio_iommu_init_iommu_mr(char *iommu_mr) "init %s"
+virtio_iommu_get_endpoint(uint32_t ep_id) "Alloc endpoint=%d"
+virtio_iommu_put_endpoint(uint32_t ep_id) "Free endpoint=%d"
+virtio_iommu_get_domain(uint32_t domain_id) "Alloc domain=%d"
+virtio_iommu_put_domain(uint32_t domain_id) "Free domain=%d"
diff --git a/hw/virtio/virtio-iommu.c b/hw/virtio/virtio-iommu.c
index 9d1b997df7..e5cc94138b 100644
--- a/hw/virtio/virtio-iommu.c
+++ b/hw/virtio/virtio-iommu.c
@@ -23,6 +23,8 @@ 
 #include "hw/qdev-properties.h"
 #include "hw/virtio/virtio.h"
 #include "sysemu/kvm.h"
+#include "qapi/error.h"
+#include "qemu/error-report.h"
 #include "trace.h"
 
 #include "standard-headers/linux/virtio_ids.h"
@@ -30,19 +32,234 @@ 
 #include "hw/virtio/virtio-bus.h"
 #include "hw/virtio/virtio-access.h"
 #include "hw/virtio/virtio-iommu.h"
+#include "hw/pci/pci_bus.h"
+#include "hw/pci/pci.h"
 
 /* Max size */
 #define VIOMMU_DEFAULT_QUEUE_SIZE 256
 
+typedef struct VirtIOIOMMUDomain {
+    uint32_t id;
+    GTree *mappings;
+    QLIST_HEAD(, VirtIOIOMMUEndpoint) endpoint_list;
+} VirtIOIOMMUDomain;
+
+typedef struct VirtIOIOMMUEndpoint {
+    uint32_t id;
+    VirtIOIOMMUDomain *domain;
+    QLIST_ENTRY(VirtIOIOMMUEndpoint) next;
+} VirtIOIOMMUEndpoint;
+
+typedef struct VirtIOIOMMUInterval {
+    uint64_t low;
+    uint64_t high;
+} VirtIOIOMMUInterval;
+
+static inline uint16_t virtio_iommu_get_bdf(IOMMUDevice *dev)
+{
+    return PCI_BUILD_BDF(pci_bus_num(dev->bus), dev->devfn);
+}
+
+/**
+ * The bus number is used for lookup when SID based operations occur.
+ * In that case we lazily populate the IOMMUPciBus array from the bus hash
+ * table. At the time the IOMMUPciBus is created (iommu_find_add_as), the bus
+ * numbers may not be always initialized yet.
+ */
+static IOMMUPciBus *iommu_find_iommu_pcibus(VirtIOIOMMU *s, uint8_t bus_num)
+{
+    IOMMUPciBus *iommu_pci_bus = s->iommu_pcibus_by_bus_num[bus_num];
+
+    if (!iommu_pci_bus) {
+        GHashTableIter iter;
+
+        g_hash_table_iter_init(&iter, s->as_by_busptr);
+        while (g_hash_table_iter_next(&iter, NULL, (void **)&iommu_pci_bus)) {
+            if (pci_bus_num(iommu_pci_bus->bus) == bus_num) {
+                s->iommu_pcibus_by_bus_num[bus_num] = iommu_pci_bus;
+                return iommu_pci_bus;
+            }
+        }
+        return NULL;
+    }
+    return iommu_pci_bus;
+}
+
+static IOMMUMemoryRegion *virtio_iommu_mr(VirtIOIOMMU *s, uint32_t sid)
+{
+    uint8_t bus_n, devfn;
+    IOMMUPciBus *iommu_pci_bus;
+    IOMMUDevice *dev;
+
+    bus_n = PCI_BUS_NUM(sid);
+    iommu_pci_bus = iommu_find_iommu_pcibus(s, bus_n);
+    if (iommu_pci_bus) {
+        devfn = sid & PCI_DEVFN_MAX;
+        dev = iommu_pci_bus->pbdev[devfn];
+        if (dev) {
+            return &dev->iommu_mr;
+        }
+    }
+    return NULL;
+}
+
+static gint interval_cmp(gconstpointer a, gconstpointer b, gpointer user_data)
+{
+    VirtIOIOMMUInterval *inta = (VirtIOIOMMUInterval *)a;
+    VirtIOIOMMUInterval *intb = (VirtIOIOMMUInterval *)b;
+
+    if (inta->high < intb->low) {
+        return -1;
+    } else if (intb->high < inta->low) {
+        return 1;
+    } else {
+        return 0;
+    }
+}
+
+static void virtio_iommu_detach_endpoint_from_domain(VirtIOIOMMUEndpoint *ep)
+{
+    QLIST_REMOVE(ep, next);
+    g_tree_unref(ep->domain->mappings);
+    ep->domain = NULL;
+}
+
+static VirtIOIOMMUEndpoint *virtio_iommu_get_endpoint(VirtIOIOMMU *s,
+                                                      uint32_t ep_id)
+{
+    VirtIOIOMMUEndpoint *ep;
+
+    ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(ep_id));
+    if (ep) {
+        return ep;
+    }
+    if (!virtio_iommu_mr(s, ep_id)) {
+        return NULL;
+    }
+    ep = g_malloc0(sizeof(*ep));
+    ep->id = ep_id;
+    trace_virtio_iommu_get_endpoint(ep_id);
+    g_tree_insert(s->endpoints, GUINT_TO_POINTER(ep_id), ep);
+    return ep;
+}
+
+static void virtio_iommu_put_endpoint(gpointer data)
+{
+    VirtIOIOMMUEndpoint *ep = (VirtIOIOMMUEndpoint *)data;
+
+    assert(!ep->domain);
+
+    trace_virtio_iommu_put_endpoint(ep->id);
+    g_free(ep);
+}
+
+static VirtIOIOMMUDomain *virtio_iommu_get_domain(VirtIOIOMMU *s,
+                                                  uint32_t domain_id)
+{
+    VirtIOIOMMUDomain *domain;
+
+    domain = g_tree_lookup(s->domains, GUINT_TO_POINTER(domain_id));
+    if (domain) {
+        return domain;
+    }
+    domain = g_malloc0(sizeof(*domain));
+    domain->id = domain_id;
+    domain->mappings = g_tree_new_full((GCompareDataFunc)interval_cmp,
+                                   NULL, (GDestroyNotify)g_free,
+                                   (GDestroyNotify)g_free);
+    g_tree_insert(s->domains, GUINT_TO_POINTER(domain_id), domain);
+    QLIST_INIT(&domain->endpoint_list);
+    trace_virtio_iommu_get_domain(domain_id);
+    return domain;
+}
+
+static void virtio_iommu_put_domain(gpointer data)
+{
+    VirtIOIOMMUDomain *domain = (VirtIOIOMMUDomain *)data;
+    VirtIOIOMMUEndpoint *iter, *tmp;
+
+    QLIST_FOREACH_SAFE(iter, &domain->endpoint_list, next, tmp) {
+        virtio_iommu_detach_endpoint_from_domain(iter);
+    }
+    trace_virtio_iommu_put_domain(domain->id);
+    g_free(domain);
+}
+
+static AddressSpace *virtio_iommu_find_add_as(PCIBus *bus, void *opaque,
+                                              int devfn)
+{
+    VirtIOIOMMU *s = opaque;
+    IOMMUPciBus *sbus = g_hash_table_lookup(s->as_by_busptr, bus);
+    static uint32_t mr_index;
+    IOMMUDevice *sdev;
+
+    if (!sbus) {
+        sbus = g_malloc0(sizeof(IOMMUPciBus) +
+                         sizeof(IOMMUDevice *) * PCI_DEVFN_MAX);
+        sbus->bus = bus;
+        g_hash_table_insert(s->as_by_busptr, bus, sbus);
+    }
+
+    sdev = sbus->pbdev[devfn];
+    if (!sdev) {
+        char *name = g_strdup_printf("%s-%d-%d",
+                                     TYPE_VIRTIO_IOMMU_MEMORY_REGION,
+                                     mr_index++, devfn);
+        sdev = sbus->pbdev[devfn] = g_malloc0(sizeof(IOMMUDevice));
+
+        sdev->viommu = s;
+        sdev->bus = bus;
+        sdev->devfn = devfn;
+
+        trace_virtio_iommu_init_iommu_mr(name);
+
+        memory_region_init_iommu(&sdev->iommu_mr, sizeof(sdev->iommu_mr),
+                                 TYPE_VIRTIO_IOMMU_MEMORY_REGION,
+                                 OBJECT(s), name,
+                                 UINT64_MAX);
+        address_space_init(&sdev->as,
+                           MEMORY_REGION(&sdev->iommu_mr), TYPE_VIRTIO_IOMMU);
+        g_free(name);
+    }
+    return &sdev->as;
+}
+
 static int virtio_iommu_attach(VirtIOIOMMU *s,
                                struct virtio_iommu_req_attach *req)
 {
     uint32_t domain_id = le32_to_cpu(req->domain);
     uint32_t ep_id = le32_to_cpu(req->endpoint);
+    VirtIOIOMMUDomain *domain;
+    VirtIOIOMMUEndpoint *ep;
 
     trace_virtio_iommu_attach(domain_id, ep_id);
 
-    return VIRTIO_IOMMU_S_UNSUPP;
+    ep = virtio_iommu_get_endpoint(s, ep_id);
+    if (!ep) {
+        return VIRTIO_IOMMU_S_NOENT;
+    }
+
+    if (ep->domain) {
+        VirtIOIOMMUDomain *previous_domain = ep->domain;
+        /*
+         * the device is already attached to a domain,
+         * detach it first
+         */
+        virtio_iommu_detach_endpoint_from_domain(ep);
+        if (QLIST_EMPTY(&previous_domain->endpoint_list)) {
+            g_tree_remove(s->domains, GUINT_TO_POINTER(previous_domain->id));
+        }
+    }
+
+    domain = virtio_iommu_get_domain(s, domain_id);
+    if (!QLIST_EMPTY(&domain->endpoint_list)) {
+        g_tree_ref(domain->mappings);
+    }
+    QLIST_INSERT_HEAD(&domain->endpoint_list, ep, next);
+
+    ep->domain = domain;
+
+    return VIRTIO_IOMMU_S_OK;
 }
 
 static int virtio_iommu_detach(VirtIOIOMMU *s,
@@ -50,10 +267,34 @@  static int virtio_iommu_detach(VirtIOIOMMU *s,
 {
     uint32_t domain_id = le32_to_cpu(req->domain);
     uint32_t ep_id = le32_to_cpu(req->endpoint);
+    VirtIOIOMMUDomain *domain;
+    VirtIOIOMMUEndpoint *ep;
 
     trace_virtio_iommu_detach(domain_id, ep_id);
 
-    return VIRTIO_IOMMU_S_UNSUPP;
+    ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(ep_id));
+    if (!ep) {
+        return VIRTIO_IOMMU_S_NOENT;
+    }
+
+    domain = ep->domain;
+
+    if (!domain || domain->id != domain_id) {
+        return VIRTIO_IOMMU_S_INVAL;
+    }
+
+    virtio_iommu_detach_endpoint_from_domain(ep);
+
+    /*
+     * when the last EP is detached, simply remove the domain for
+     * the domain list and destroy it. Note its mappings were already
+     * freed by the ref count mechanism. Next operation involving
+     * the same domain id will re-create one domain object.
+     */
+    if (QLIST_EMPTY(&domain->endpoint_list)) {
+        g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
+    }
+    return VIRTIO_IOMMU_S_OK;
 }
 
 static int virtio_iommu_map(VirtIOIOMMU *s,
@@ -172,6 +413,27 @@  out:
     }
 }
 
+static IOMMUTLBEntry virtio_iommu_translate(IOMMUMemoryRegion *mr, hwaddr addr,
+                                            IOMMUAccessFlags flag,
+                                            int iommu_idx)
+{
+    IOMMUDevice *sdev = container_of(mr, IOMMUDevice, iommu_mr);
+    uint32_t sid;
+
+    IOMMUTLBEntry entry = {
+        .target_as = &address_space_memory,
+        .iova = addr,
+        .translated_addr = addr,
+        .addr_mask = ~(hwaddr)0,
+        .perm = IOMMU_NONE,
+    };
+
+    sid = virtio_iommu_get_bdf(sdev);
+
+    trace_virtio_iommu_translate(mr->parent_obj.name, sid, addr, flag);
+    return entry;
+}
+
 static void virtio_iommu_get_config(VirtIODevice *vdev, uint8_t *config_data)
 {
     VirtIOIOMMU *dev = VIRTIO_IOMMU(vdev);
@@ -218,6 +480,13 @@  static const VMStateDescription vmstate_virtio_iommu_device = {
     .unmigratable = 1,
 };
 
+static gint int_cmp(gconstpointer a, gconstpointer b, gpointer user_data)
+{
+    guint ua = GPOINTER_TO_UINT(a);
+    guint ub = GPOINTER_TO_UINT(b);
+    return (ua > ub) - (ua < ub);
+}
+
 static void virtio_iommu_device_realize(DeviceState *dev, Error **errp)
 {
     VirtIODevice *vdev = VIRTIO_DEVICE(dev);
@@ -226,6 +495,8 @@  static void virtio_iommu_device_realize(DeviceState *dev, Error **errp)
     virtio_init(vdev, "virtio-iommu", VIRTIO_ID_IOMMU,
                 sizeof(struct virtio_iommu_config));
 
+    memset(s->iommu_pcibus_by_bus_num, 0, sizeof(s->iommu_pcibus_by_bus_num));
+
     s->req_vq = virtio_add_queue(vdev, VIOMMU_DEFAULT_QUEUE_SIZE,
                              virtio_iommu_handle_command);
     s->event_vq = virtio_add_queue(vdev, VIOMMU_DEFAULT_QUEUE_SIZE, NULL);
@@ -244,18 +515,43 @@  static void virtio_iommu_device_realize(DeviceState *dev, Error **errp)
     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_MMIO);
 
     qemu_mutex_init(&s->mutex);
+
+    s->as_by_busptr = g_hash_table_new_full(NULL, NULL, NULL, g_free);
+
+    if (s->primary_bus) {
+        pci_setup_iommu(s->primary_bus, virtio_iommu_find_add_as, s);
+    } else {
+        error_setg(errp, "VIRTIO-IOMMU is not attached to any PCI bus!");
+    }
 }
 
 static void virtio_iommu_device_unrealize(DeviceState *dev, Error **errp)
 {
     VirtIODevice *vdev = VIRTIO_DEVICE(dev);
+    VirtIOIOMMU *s = VIRTIO_IOMMU(dev);
+
+    g_tree_destroy(s->domains);
+    g_tree_destroy(s->endpoints);
 
     virtio_cleanup(vdev);
 }
 
 static void virtio_iommu_device_reset(VirtIODevice *vdev)
 {
+    VirtIOIOMMU *s = VIRTIO_IOMMU(vdev);
+
     trace_virtio_iommu_device_reset();
+
+    if (s->domains) {
+        g_tree_destroy(s->domains);
+    }
+    if (s->endpoints) {
+        g_tree_destroy(s->endpoints);
+    }
+    s->domains = g_tree_new_full((GCompareDataFunc)int_cmp,
+                                 NULL, NULL, virtio_iommu_put_domain);
+    s->endpoints = g_tree_new_full((GCompareDataFunc)int_cmp,
+                                   NULL, NULL, virtio_iommu_put_endpoint);
 }
 
 static void virtio_iommu_set_status(VirtIODevice *vdev, uint8_t status)
@@ -301,6 +597,14 @@  static void virtio_iommu_class_init(ObjectClass *klass, void *data)
     vdc->vmsd = &vmstate_virtio_iommu_device;
 }
 
+static void virtio_iommu_memory_region_class_init(ObjectClass *klass,
+                                                  void *data)
+{
+    IOMMUMemoryRegionClass *imrc = IOMMU_MEMORY_REGION_CLASS(klass);
+
+    imrc->translate = virtio_iommu_translate;
+}
+
 static const TypeInfo virtio_iommu_info = {
     .name = TYPE_VIRTIO_IOMMU,
     .parent = TYPE_VIRTIO_DEVICE,
@@ -309,9 +613,16 @@  static const TypeInfo virtio_iommu_info = {
     .class_init = virtio_iommu_class_init,
 };
 
+static const TypeInfo virtio_iommu_memory_region_info = {
+    .parent = TYPE_IOMMU_MEMORY_REGION,
+    .name = TYPE_VIRTIO_IOMMU_MEMORY_REGION,
+    .class_init = virtio_iommu_memory_region_class_init,
+};
+
 static void virtio_register_types(void)
 {
     type_register_static(&virtio_iommu_info);
+    type_register_static(&virtio_iommu_memory_region_info);
 }
 
 type_init(virtio_register_types)
diff --git a/include/hw/virtio/virtio-iommu.h b/include/hw/virtio/virtio-iommu.h
index 041f2c9390..2a2c2ecf83 100644
--- a/include/hw/virtio/virtio-iommu.h
+++ b/include/hw/virtio/virtio-iommu.h
@@ -28,6 +28,8 @@ 
 #define VIRTIO_IOMMU(obj) \
         OBJECT_CHECK(VirtIOIOMMU, (obj), TYPE_VIRTIO_IOMMU)
 
+#define TYPE_VIRTIO_IOMMU_MEMORY_REGION "virtio-iommu-memory-region"
+
 typedef struct IOMMUDevice {
     void         *viommu;
     PCIBus       *bus;
@@ -48,6 +50,7 @@  typedef struct VirtIOIOMMU {
     struct virtio_iommu_config config;
     uint64_t features;
     GHashTable *as_by_busptr;
+    IOMMUPciBus *iommu_pcibus_by_bus_num[PCI_BUS_MAX];
     PCIBus *primary_bus;
     GTree *domains;
     QemuMutex mutex;