diff mbox series

[RFC,03/12] hw/arm/smmu: Add stage to TLB

Message ID 20240325101442.1306300-4-smostafa@google.com
State New
Headers show
Series SMMUv3 nested translation support | expand

Commit Message

Mostafa Saleh March 25, 2024, 10:13 a.m. UTC
TLBs for nesting will be extended to be combined, a new index is added
"stage", with 2 valid values:
 - SMMU_STAGE_1: Meaning this translates VA to PADDR, this entry can
   be cached from fully nested configuration or from stage-1 only.
   We don't support separate cached entries (VA to IPA)

 - SMMU_STAGE_2: Meaning this translates IPA to PADDR, cached from
   stage-2 only configuration.

For TLB invalidation:
 - by VA: Invalidate TLBs tagged with SMMU_STAGE_1
 - by IPA: Invalidate TLBs tagged with SMMU_STAGE_2
 - All: Will invalidate both; this is communicated to the TLB as
   SMMU_NESTED, which is (SMMU_STAGE_1 | SMMU_STAGE_2) and is used
   as a mask.

This is briefly described in the user manual (ARM IHI 0070 F.b) in
"16.2.1 Caching combined structures".

Signed-off-by: Mostafa Saleh <smostafa@google.com>
---
 hw/arm/smmu-common.c         | 27 +++++++++++++++++----------
 hw/arm/smmu-internal.h       |  2 ++
 hw/arm/smmuv3.c              |  5 +++--
 hw/arm/trace-events          |  3 ++-
 include/hw/arm/smmu-common.h |  8 ++++++--
 5 files changed, 30 insertions(+), 15 deletions(-)

Comments

Eric Auger April 2, 2024, 5:15 p.m. UTC | #1
Hi Mostafa,

On 3/25/24 11:13, Mostafa Saleh wrote:
> TLBs for nesting will be extended to be combined, a new index is added
> "stage", with 2 valid values:
>  - SMMU_STAGE_1: Meaning this translates VA to PADDR, this entry can
>    be cached from fully nested configuration or from stage-1 only.
>    We don't support separate cached entries (VA to IPA)
>
>  - SMMU_STAGE_2: Meaning this translates IPA to PADDR, cached from
>    stage-2 only configuration.
>
> For TLB invalidation:
>  - by VA: Invalidate TLBs tagged with SMMU_STAGE_1
>  - by IPA: Invalidate TLBs tagged with SMMU_STAGE_2
>  - All: Will invalidate both, this is communicated to the TLB as
>    SMMU_NESTED which is (SMMU_STAGE_1 | SMMU_STAGE_2) which uses
>    it as a mask.

I don't really get why you need this extra stage field in the key. Why
aren't the asid and vmid tags enough?

Eric
>
> This briefly described in the user manual (ARM IHI 0070 F.b) in
> "16.2.1 Caching combined structures".
>
> Signed-off-by: Mostafa Saleh <smostafa@google.com>
> ---
>  hw/arm/smmu-common.c         | 27 +++++++++++++++++----------
>  hw/arm/smmu-internal.h       |  2 ++
>  hw/arm/smmuv3.c              |  5 +++--
>  hw/arm/trace-events          |  3 ++-
>  include/hw/arm/smmu-common.h |  8 ++++++--
>  5 files changed, 30 insertions(+), 15 deletions(-)
>
> diff --git a/hw/arm/smmu-common.c b/hw/arm/smmu-common.c
> index 20630eb670..677dcf9a13 100644
> --- a/hw/arm/smmu-common.c
> +++ b/hw/arm/smmu-common.c
> @@ -38,7 +38,7 @@ static guint smmu_iotlb_key_hash(gconstpointer v)
>  
>      /* Jenkins hash */
>      a = b = c = JHASH_INITVAL + sizeof(*key);
> -    a += key->asid + key->vmid + key->level + key->tg;
> +    a += key->asid + key->vmid + key->level + key->tg + key->stage;
>      b += extract64(key->iova, 0, 32);
>      c += extract64(key->iova, 32, 32);
>  
> @@ -54,14 +54,14 @@ static gboolean smmu_iotlb_key_equal(gconstpointer v1, gconstpointer v2)
>  
>      return (k1->asid == k2->asid) && (k1->iova == k2->iova) &&
>             (k1->level == k2->level) && (k1->tg == k2->tg) &&
> -           (k1->vmid == k2->vmid);
> +           (k1->vmid == k2->vmid) && (k1->stage == k2->stage);
>  }
>  
>  SMMUIOTLBKey smmu_get_iotlb_key(uint16_t asid, uint16_t vmid, uint64_t iova,
> -                                uint8_t tg, uint8_t level)
> +                                uint8_t tg, uint8_t level, SMMUStage stage)
>  {
>      SMMUIOTLBKey key = {.asid = asid, .vmid = vmid, .iova = iova,
> -                        .tg = tg, .level = level};
> +                        .tg = tg, .level = level, .stage = stage};
>  
>      return key;
>  }
> @@ -81,7 +81,8 @@ SMMUTLBEntry *smmu_iotlb_lookup(SMMUState *bs, SMMUTransCfg *cfg,
>          SMMUIOTLBKey key;
>  
>          key = smmu_get_iotlb_key(cfg->asid, cfg->s2cfg.vmid,
> -                                 iova & ~mask, tg, level);
> +                                 iova & ~mask, tg, level,
> +                                 SMMU_STAGE_TO_TLB_TAG(cfg->stage));
>          entry = g_hash_table_lookup(bs->iotlb, &key);
>          if (entry) {
>              break;
> @@ -109,15 +110,16 @@ void smmu_iotlb_insert(SMMUState *bs, SMMUTransCfg *cfg, SMMUTLBEntry *new)
>  {
>      SMMUIOTLBKey *key = g_new0(SMMUIOTLBKey, 1);
>      uint8_t tg = (new->granule - 10) / 2;
> +    SMMUStage stage_tag = SMMU_STAGE_TO_TLB_TAG(cfg->stage);
>  
>      if (g_hash_table_size(bs->iotlb) >= SMMU_IOTLB_MAX_SIZE) {
>          smmu_iotlb_inv_all(bs);
>      }
>  
>      *key = smmu_get_iotlb_key(cfg->asid, cfg->s2cfg.vmid, new->entry.iova,
> -                              tg, new->level);
> +                              tg, new->level, stage_tag);
>      trace_smmu_iotlb_insert(cfg->asid, cfg->s2cfg.vmid, new->entry.iova,
> -                            tg, new->level);
> +                            tg, new->level, stage_tag);
>      g_hash_table_insert(bs->iotlb, key, new);
>  }
>  
> @@ -159,18 +161,22 @@ static gboolean smmu_hash_remove_by_asid_vmid_iova(gpointer key, gpointer value,
>      if (info->vmid >= 0 && info->vmid != SMMU_IOTLB_VMID(iotlb_key)) {
>          return false;
>      }
> +    if (!(info->stage & SMMU_IOTLB_STAGE(iotlb_key))) {
> +        return false;
> +    }
>      return ((info->iova & ~entry->addr_mask) == entry->iova) ||
>             ((entry->iova & ~info->mask) == info->iova);
>  }
>  
>  void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
> -                         uint8_t tg, uint64_t num_pages, uint8_t ttl)
> +                         uint8_t tg, uint64_t num_pages, uint8_t ttl,
> +                         SMMUStage stage)
>  {
>      /* if tg is not set we use 4KB range invalidation */
>      uint8_t granule = tg ? tg * 2 + 10 : 12;
>  
>      if (ttl && (num_pages == 1) && (asid >= 0)) {
> -        SMMUIOTLBKey key = smmu_get_iotlb_key(asid, vmid, iova, tg, ttl);
> +        SMMUIOTLBKey key = smmu_get_iotlb_key(asid, vmid, iova, tg, ttl, stage);
>  
>          if (g_hash_table_remove(s->iotlb, &key)) {
>              return;
> @@ -184,6 +190,7 @@ void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
>      SMMUIOTLBPageInvInfo info = {
>          .asid = asid, .iova = iova,
>          .vmid = vmid,
> +        .stage = stage,
>          .mask = (num_pages * 1 << granule) - 1};
>  
>      g_hash_table_foreach_remove(s->iotlb,
> @@ -597,7 +604,7 @@ SMMUTLBEntry *smmu_translate(SMMUState *bs, SMMUTransCfg *cfg, dma_addr_t addr,
>      if (cached_entry) {
>          if ((flag & IOMMU_WO) && !(cached_entry->entry.perm & IOMMU_WO)) {
>              info->type = SMMU_PTW_ERR_PERMISSION;
> -            info->stage = cfg->stage;
> +            info->stage = SMMU_STAGE_TO_TLB_TAG(cfg->stage);
>              return NULL;
>          }
>          return cached_entry;
> diff --git a/hw/arm/smmu-internal.h b/hw/arm/smmu-internal.h
> index 843bebb185..6caa0ddf21 100644
> --- a/hw/arm/smmu-internal.h
> +++ b/hw/arm/smmu-internal.h
> @@ -133,12 +133,14 @@ static inline int pgd_concat_idx(int start_level, int granule_sz,
>  
>  #define SMMU_IOTLB_ASID(key) ((key).asid)
>  #define SMMU_IOTLB_VMID(key) ((key).vmid)
> +#define SMMU_IOTLB_STAGE(key) ((key).stage)
>  
>  typedef struct SMMUIOTLBPageInvInfo {
>      int asid;
>      int vmid;
>      uint64_t iova;
>      uint64_t mask;
> +    SMMUStage stage;
>  } SMMUIOTLBPageInvInfo;
>  
>  typedef struct SMMUSIDRange {
> diff --git a/hw/arm/smmuv3.c b/hw/arm/smmuv3.c
> index f081ff0cc4..b27bf297e1 100644
> --- a/hw/arm/smmuv3.c
> +++ b/hw/arm/smmuv3.c
> @@ -1087,7 +1087,7 @@ static void smmuv3_range_inval(SMMUState *s, Cmd *cmd)
>      if (!tg) {
>          trace_smmuv3_range_inval(vmid, asid, addr, tg, 1, ttl, leaf);
>          smmuv3_inv_notifiers_iova(s, asid, vmid, addr, tg, 1);
> -        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, 1, ttl);
> +        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, 1, ttl, SMMU_NESTED);
>          return;
>      }
>  
> @@ -1105,7 +1105,8 @@ static void smmuv3_range_inval(SMMUState *s, Cmd *cmd)
>          num_pages = (mask + 1) >> granule;
>          trace_smmuv3_range_inval(vmid, asid, addr, tg, num_pages, ttl, leaf);
>          smmuv3_inv_notifiers_iova(s, asid, vmid, addr, tg, num_pages);
> -        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, num_pages, ttl);
> +        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg,
> +                            num_pages, ttl, SMMU_NESTED);
>          addr += mask + 1;
>      }
>  }
> diff --git a/hw/arm/trace-events b/hw/arm/trace-events
> index cc12924a84..3000c3bf14 100644
> --- a/hw/arm/trace-events
> +++ b/hw/arm/trace-events
> @@ -14,10 +14,11 @@ smmu_iotlb_inv_all(void) "IOTLB invalidate all"
>  smmu_iotlb_inv_asid(uint16_t asid) "IOTLB invalidate asid=%d"
>  smmu_iotlb_inv_vmid(uint16_t vmid) "IOTLB invalidate vmid=%d"
>  smmu_iotlb_inv_iova(uint16_t asid, uint64_t addr) "IOTLB invalidate asid=%d addr=0x%"PRIx64
> +smmu_iotlb_inv_stage(int stage) "Stage invalidate stage=%d"
>  smmu_inv_notifiers_mr(const char *name) "iommu mr=%s"
>  smmu_iotlb_lookup_hit(uint16_t asid, uint16_t vmid, uint64_t addr, uint32_t hit, uint32_t miss, uint32_t p) "IOTLB cache HIT asid=%d vmid=%d addr=0x%"PRIx64" hit=%d miss=%d hit rate=%d"
>  smmu_iotlb_lookup_miss(uint16_t asid, uint16_t vmid, uint64_t addr, uint32_t hit, uint32_t miss, uint32_t p) "IOTLB cache MISS asid=%d vmid=%d addr=0x%"PRIx64" hit=%d miss=%d hit rate=%d"
> -smmu_iotlb_insert(uint16_t asid, uint16_t vmid, uint64_t addr, uint8_t tg, uint8_t level) "IOTLB ++ asid=%d vmid=%d addr=0x%"PRIx64" tg=%d level=%d"
> +smmu_iotlb_insert(uint16_t asid, uint16_t vmid, uint64_t addr, uint8_t tg, uint8_t level, int stage) "IOTLB ++ asid=%d vmid=%d addr=0x%"PRIx64" tg=%d level=%d stage=%d"
>  
>  # smmuv3.c
>  smmuv3_read_mmio(uint64_t addr, uint64_t val, unsigned size, uint32_t r) "addr: 0x%"PRIx64" val:0x%"PRIx64" size: 0x%x(%d)"
> diff --git a/include/hw/arm/smmu-common.h b/include/hw/arm/smmu-common.h
> index 876e78975c..695d6d10ad 100644
> --- a/include/hw/arm/smmu-common.h
> +++ b/include/hw/arm/smmu-common.h
> @@ -37,6 +37,8 @@
>  #define VMSA_IDXMSK(isz, strd, lvl)         ((1ULL << \
>                                               VMSA_BIT_LVL(isz, strd, lvl)) - 1)
>  
> +#define SMMU_STAGE_TO_TLB_TAG(stage)        (((stage) == SMMU_NESTED) ? \
> +                                             SMMU_STAGE_1 : (stage))
>  /*
>   * Page table walk error types
>   */
> @@ -136,6 +138,7 @@ typedef struct SMMUIOTLBKey {
>      uint16_t vmid;
>      uint8_t tg;
>      uint8_t level;
> +    SMMUStage stage;
>  } SMMUIOTLBKey;
>  
>  struct SMMUState {
> @@ -203,12 +206,13 @@ SMMUTLBEntry *smmu_iotlb_lookup(SMMUState *bs, SMMUTransCfg *cfg,
>                                  SMMUTransTableInfo *tt, hwaddr iova);
>  void smmu_iotlb_insert(SMMUState *bs, SMMUTransCfg *cfg, SMMUTLBEntry *entry);
>  SMMUIOTLBKey smmu_get_iotlb_key(uint16_t asid, uint16_t vmid, uint64_t iova,
> -                                uint8_t tg, uint8_t level);
> +                                uint8_t tg, uint8_t level, SMMUStage stage);
>  void smmu_iotlb_inv_all(SMMUState *s);
>  void smmu_iotlb_inv_asid(SMMUState *s, uint16_t asid);
>  void smmu_iotlb_inv_vmid(SMMUState *s, uint16_t vmid);
>  void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
> -                         uint8_t tg, uint64_t num_pages, uint8_t ttl);
> +                         uint8_t tg, uint64_t num_pages, uint8_t ttl,
> +                         SMMUStage stage);
>  
>  /* Unmap the range of all the notifiers registered to any IOMMU mr */
>  void smmu_inv_notifiers_all(SMMUState *s);
Mostafa Saleh April 2, 2024, 6:47 p.m. UTC | #2
Hi Eric,

On Tue, Apr 02, 2024 at 07:15:20PM +0200, Eric Auger wrote:
> Hi Mostafa,
> 
> On 3/25/24 11:13, Mostafa Saleh wrote:
> > TLBs for nesting will be extended to be combined, a new index is added
> > "stage", with 2 valid values:
> >  - SMMU_STAGE_1: Meaning this translates VA to PADDR, this entry can
> >    be cached from fully nested configuration or from stage-1 only.
> >    We don't support separate cached entries (VA to IPA)
> >
> >  - SMMU_STAGE_2: Meaning this translates IPA to PADDR, cached from
> >    stage-2 only configuration.
> >
> > For TLB invalidation:
> >  - by VA: Invalidate TLBs tagged with SMMU_STAGE_1
> >  - by IPA: Invalidate TLBs tagged with SMMU_STAGE_2
> >  - All: Will invalidate both, this is communicated to the TLB as
> >    SMMU_NESTED which is (SMMU_STAGE_1 | SMMU_STAGE_2) which uses
> >    it as a mask.
> 
> I don't really get why you need this extra stage field in the key. Why
> aren't the asid and vmid tags enough?
> 

Looking again, I think we can do it with ASID and VMID only, but that
requires some rework in the invalidation path.

With nested SMMUs, we can cache entries from:
- Stage-1 (or nested): Tagged with VMID and ASID
- Stage-2: Tagged with VMID only (ASID = -1)

That should be enough for caching/lookup, but for invalidation, we
should be able to invalidate IPAs which are cached from stage-2.

At the moment, we represent ASIDs with < 0 as a wildcard for
invalidation or stage-2 and they were mutually exclusive.

An example is:
- CMD_TLBI_NH_VAA: Invalidate stage-1 for a VMID, all ASIDs (we use ASID = -1)
- CMD_TLBI_NH_VA: Invalidate stage-1 for a VMID, an ASID (>= 0)
- CMD_TLBI_S2_IPA: Invalidate stage-2 for a VMID (we use ASID = -1)

We need to distinguish between case 1) and 3) otherwise we over invalidate.

Similarly, CMD_TLBI_NH_ALL(invalidate all stage-1 by VMID) and
CMD_TLBI_S12_VMALL(invalidate both stages by VMID).

I guess we can add variants of these functions that operate on ASIDs
(>= 0) or (< 0) which is basically stage-1 or stage-2.

Another case I can think of which is not implemented in QEMU is
global entries, where we would like to look up entries for all ASIDs
(-1), but that’s not a problem for now.

I don’t have a strong opinion, I can try to do it this way.

Thanks,
Mostafa

> Eric
> >
> > This briefly described in the user manual (ARM IHI 0070 F.b) in
> > "16.2.1 Caching combined structures".
> >
> > Signed-off-by: Mostafa Saleh <smostafa@google.com>
> > ---
> >  hw/arm/smmu-common.c         | 27 +++++++++++++++++----------
> >  hw/arm/smmu-internal.h       |  2 ++
> >  hw/arm/smmuv3.c              |  5 +++--
> >  hw/arm/trace-events          |  3 ++-
> >  include/hw/arm/smmu-common.h |  8 ++++++--
> >  5 files changed, 30 insertions(+), 15 deletions(-)
> >
> > diff --git a/hw/arm/smmu-common.c b/hw/arm/smmu-common.c
> > index 20630eb670..677dcf9a13 100644
> > --- a/hw/arm/smmu-common.c
> > +++ b/hw/arm/smmu-common.c
> > @@ -38,7 +38,7 @@ static guint smmu_iotlb_key_hash(gconstpointer v)
> >  
> >      /* Jenkins hash */
> >      a = b = c = JHASH_INITVAL + sizeof(*key);
> > -    a += key->asid + key->vmid + key->level + key->tg;
> > +    a += key->asid + key->vmid + key->level + key->tg + key->stage;
> >      b += extract64(key->iova, 0, 32);
> >      c += extract64(key->iova, 32, 32);
> >  
> > @@ -54,14 +54,14 @@ static gboolean smmu_iotlb_key_equal(gconstpointer v1, gconstpointer v2)
> >  
> >      return (k1->asid == k2->asid) && (k1->iova == k2->iova) &&
> >             (k1->level == k2->level) && (k1->tg == k2->tg) &&
> > -           (k1->vmid == k2->vmid);
> > +           (k1->vmid == k2->vmid) && (k1->stage == k2->stage);
> >  }
> >  
> >  SMMUIOTLBKey smmu_get_iotlb_key(uint16_t asid, uint16_t vmid, uint64_t iova,
> > -                                uint8_t tg, uint8_t level)
> > +                                uint8_t tg, uint8_t level, SMMUStage stage)
> >  {
> >      SMMUIOTLBKey key = {.asid = asid, .vmid = vmid, .iova = iova,
> > -                        .tg = tg, .level = level};
> > +                        .tg = tg, .level = level, .stage = stage};
> >  
> >      return key;
> >  }
> > @@ -81,7 +81,8 @@ SMMUTLBEntry *smmu_iotlb_lookup(SMMUState *bs, SMMUTransCfg *cfg,
> >          SMMUIOTLBKey key;
> >  
> >          key = smmu_get_iotlb_key(cfg->asid, cfg->s2cfg.vmid,
> > -                                 iova & ~mask, tg, level);
> > +                                 iova & ~mask, tg, level,
> > +                                 SMMU_STAGE_TO_TLB_TAG(cfg->stage));
> >          entry = g_hash_table_lookup(bs->iotlb, &key);
> >          if (entry) {
> >              break;
> > @@ -109,15 +110,16 @@ void smmu_iotlb_insert(SMMUState *bs, SMMUTransCfg *cfg, SMMUTLBEntry *new)
> >  {
> >      SMMUIOTLBKey *key = g_new0(SMMUIOTLBKey, 1);
> >      uint8_t tg = (new->granule - 10) / 2;
> > +    SMMUStage stage_tag = SMMU_STAGE_TO_TLB_TAG(cfg->stage);
> >  
> >      if (g_hash_table_size(bs->iotlb) >= SMMU_IOTLB_MAX_SIZE) {
> >          smmu_iotlb_inv_all(bs);
> >      }
> >  
> >      *key = smmu_get_iotlb_key(cfg->asid, cfg->s2cfg.vmid, new->entry.iova,
> > -                              tg, new->level);
> > +                              tg, new->level, stage_tag);
> >      trace_smmu_iotlb_insert(cfg->asid, cfg->s2cfg.vmid, new->entry.iova,
> > -                            tg, new->level);
> > +                            tg, new->level, stage_tag);
> >      g_hash_table_insert(bs->iotlb, key, new);
> >  }
> >  
> > @@ -159,18 +161,22 @@ static gboolean smmu_hash_remove_by_asid_vmid_iova(gpointer key, gpointer value,
> >      if (info->vmid >= 0 && info->vmid != SMMU_IOTLB_VMID(iotlb_key)) {
> >          return false;
> >      }
> > +    if (!(info->stage & SMMU_IOTLB_STAGE(iotlb_key))) {
> > +        return false;
> > +    }
> >      return ((info->iova & ~entry->addr_mask) == entry->iova) ||
> >             ((entry->iova & ~info->mask) == info->iova);
> >  }
> >  
> >  void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
> > -                         uint8_t tg, uint64_t num_pages, uint8_t ttl)
> > +                         uint8_t tg, uint64_t num_pages, uint8_t ttl,
> > +                         SMMUStage stage)
> >  {
> >      /* if tg is not set we use 4KB range invalidation */
> >      uint8_t granule = tg ? tg * 2 + 10 : 12;
> >  
> >      if (ttl && (num_pages == 1) && (asid >= 0)) {
> > -        SMMUIOTLBKey key = smmu_get_iotlb_key(asid, vmid, iova, tg, ttl);
> > +        SMMUIOTLBKey key = smmu_get_iotlb_key(asid, vmid, iova, tg, ttl, stage);
> >  
> >          if (g_hash_table_remove(s->iotlb, &key)) {
> >              return;
> > @@ -184,6 +190,7 @@ void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
> >      SMMUIOTLBPageInvInfo info = {
> >          .asid = asid, .iova = iova,
> >          .vmid = vmid,
> > +        .stage = stage,
> >          .mask = (num_pages * 1 << granule) - 1};
> >  
> >      g_hash_table_foreach_remove(s->iotlb,
> > @@ -597,7 +604,7 @@ SMMUTLBEntry *smmu_translate(SMMUState *bs, SMMUTransCfg *cfg, dma_addr_t addr,
> >      if (cached_entry) {
> >          if ((flag & IOMMU_WO) && !(cached_entry->entry.perm & IOMMU_WO)) {
> >              info->type = SMMU_PTW_ERR_PERMISSION;
> > -            info->stage = cfg->stage;
> > +            info->stage = SMMU_STAGE_TO_TLB_TAG(cfg->stage);
> >              return NULL;
> >          }
> >          return cached_entry;
> > diff --git a/hw/arm/smmu-internal.h b/hw/arm/smmu-internal.h
> > index 843bebb185..6caa0ddf21 100644
> > --- a/hw/arm/smmu-internal.h
> > +++ b/hw/arm/smmu-internal.h
> > @@ -133,12 +133,14 @@ static inline int pgd_concat_idx(int start_level, int granule_sz,
> >  
> >  #define SMMU_IOTLB_ASID(key) ((key).asid)
> >  #define SMMU_IOTLB_VMID(key) ((key).vmid)
> > +#define SMMU_IOTLB_STAGE(key) ((key).stage)
> >  
> >  typedef struct SMMUIOTLBPageInvInfo {
> >      int asid;
> >      int vmid;
> >      uint64_t iova;
> >      uint64_t mask;
> > +    SMMUStage stage;
> >  } SMMUIOTLBPageInvInfo;
> >  
> >  typedef struct SMMUSIDRange {
> > diff --git a/hw/arm/smmuv3.c b/hw/arm/smmuv3.c
> > index f081ff0cc4..b27bf297e1 100644
> > --- a/hw/arm/smmuv3.c
> > +++ b/hw/arm/smmuv3.c
> > @@ -1087,7 +1087,7 @@ static void smmuv3_range_inval(SMMUState *s, Cmd *cmd)
> >      if (!tg) {
> >          trace_smmuv3_range_inval(vmid, asid, addr, tg, 1, ttl, leaf);
> >          smmuv3_inv_notifiers_iova(s, asid, vmid, addr, tg, 1);
> > -        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, 1, ttl);
> > +        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, 1, ttl, SMMU_NESTED);
> >          return;
> >      }
> >  
> > @@ -1105,7 +1105,8 @@ static void smmuv3_range_inval(SMMUState *s, Cmd *cmd)
> >          num_pages = (mask + 1) >> granule;
> >          trace_smmuv3_range_inval(vmid, asid, addr, tg, num_pages, ttl, leaf);
> >          smmuv3_inv_notifiers_iova(s, asid, vmid, addr, tg, num_pages);
> > -        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, num_pages, ttl);
> > +        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg,
> > +                            num_pages, ttl, SMMU_NESTED);
> >          addr += mask + 1;
> >      }
> >  }
> > diff --git a/hw/arm/trace-events b/hw/arm/trace-events
> > index cc12924a84..3000c3bf14 100644
> > --- a/hw/arm/trace-events
> > +++ b/hw/arm/trace-events
> > @@ -14,10 +14,11 @@ smmu_iotlb_inv_all(void) "IOTLB invalidate all"
> >  smmu_iotlb_inv_asid(uint16_t asid) "IOTLB invalidate asid=%d"
> >  smmu_iotlb_inv_vmid(uint16_t vmid) "IOTLB invalidate vmid=%d"
> >  smmu_iotlb_inv_iova(uint16_t asid, uint64_t addr) "IOTLB invalidate asid=%d addr=0x%"PRIx64
> > +smmu_iotlb_inv_stage(int stage) "Stage invalidate stage=%d"
> >  smmu_inv_notifiers_mr(const char *name) "iommu mr=%s"
> >  smmu_iotlb_lookup_hit(uint16_t asid, uint16_t vmid, uint64_t addr, uint32_t hit, uint32_t miss, uint32_t p) "IOTLB cache HIT asid=%d vmid=%d addr=0x%"PRIx64" hit=%d miss=%d hit rate=%d"
> >  smmu_iotlb_lookup_miss(uint16_t asid, uint16_t vmid, uint64_t addr, uint32_t hit, uint32_t miss, uint32_t p) "IOTLB cache MISS asid=%d vmid=%d addr=0x%"PRIx64" hit=%d miss=%d hit rate=%d"
> > -smmu_iotlb_insert(uint16_t asid, uint16_t vmid, uint64_t addr, uint8_t tg, uint8_t level) "IOTLB ++ asid=%d vmid=%d addr=0x%"PRIx64" tg=%d level=%d"
> > +smmu_iotlb_insert(uint16_t asid, uint16_t vmid, uint64_t addr, uint8_t tg, uint8_t level, int stage) "IOTLB ++ asid=%d vmid=%d addr=0x%"PRIx64" tg=%d level=%d stage=%d"
> >  
> >  # smmuv3.c
> >  smmuv3_read_mmio(uint64_t addr, uint64_t val, unsigned size, uint32_t r) "addr: 0x%"PRIx64" val:0x%"PRIx64" size: 0x%x(%d)"
> > diff --git a/include/hw/arm/smmu-common.h b/include/hw/arm/smmu-common.h
> > index 876e78975c..695d6d10ad 100644
> > --- a/include/hw/arm/smmu-common.h
> > +++ b/include/hw/arm/smmu-common.h
> > @@ -37,6 +37,8 @@
> >  #define VMSA_IDXMSK(isz, strd, lvl)         ((1ULL << \
> >                                               VMSA_BIT_LVL(isz, strd, lvl)) - 1)
> >  
> > +#define SMMU_STAGE_TO_TLB_TAG(stage)        (((stage) == SMMU_NESTED) ? \
> > +                                             SMMU_STAGE_1 : (stage))
> >  /*
> >   * Page table walk error types
> >   */
> > @@ -136,6 +138,7 @@ typedef struct SMMUIOTLBKey {
> >      uint16_t vmid;
> >      uint8_t tg;
> >      uint8_t level;
> > +    SMMUStage stage;
> >  } SMMUIOTLBKey;
> >  
> >  struct SMMUState {
> > @@ -203,12 +206,13 @@ SMMUTLBEntry *smmu_iotlb_lookup(SMMUState *bs, SMMUTransCfg *cfg,
> >                                  SMMUTransTableInfo *tt, hwaddr iova);
> >  void smmu_iotlb_insert(SMMUState *bs, SMMUTransCfg *cfg, SMMUTLBEntry *entry);
> >  SMMUIOTLBKey smmu_get_iotlb_key(uint16_t asid, uint16_t vmid, uint64_t iova,
> > -                                uint8_t tg, uint8_t level);
> > +                                uint8_t tg, uint8_t level, SMMUStage stage);
> >  void smmu_iotlb_inv_all(SMMUState *s);
> >  void smmu_iotlb_inv_asid(SMMUState *s, uint16_t asid);
> >  void smmu_iotlb_inv_vmid(SMMUState *s, uint16_t vmid);
> >  void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
> > -                         uint8_t tg, uint64_t num_pages, uint8_t ttl);
> > +                         uint8_t tg, uint64_t num_pages, uint8_t ttl,
> > +                         SMMUStage stage);
> >  
> >  /* Unmap the range of all the notifiers registered to any IOMMU mr */
> >  void smmu_inv_notifiers_all(SMMUState *s);
>
Eric Auger April 3, 2024, 7:22 a.m. UTC | #3
Hi Mostafa,

On 4/2/24 20:47, Mostafa Saleh wrote:
> Hi Eric,
>
> On Tue, Apr 02, 2024 at 07:15:20PM +0200, Eric Auger wrote:
>> Hi Mostafa,
>>
>> On 3/25/24 11:13, Mostafa Saleh wrote:
>>> TLBs for nesting will be extended to be combined, a new index is added
>>> "stage", with 2 valid values:
>>>  - SMMU_STAGE_1: Meaning this translates VA to PADDR, this entry can
>>>    be cached from fully nested configuration or from stage-1 only.
>>>    We don't support separate cached entries (VA to IPA)
>>>
>>>  - SMMU_STAGE_2: Meaning this translates IPA to PADDR, cached from
>>>    stage-2 only configuration.
>>>
>>> For TLB invalidation:
>>>  - by VA: Invalidate TLBs tagged with SMMU_STAGE_1
>>>  - by IPA: Invalidate TLBs tagged with SMMU_STAGE_2
>>>  - All: Will invalidate both, this is communicated to the TLB as
>>>    SMMU_NESTED which is (SMMU_STAGE_1 | SMMU_STAGE_2) which uses
>>>    it as a mask.
>> I don't really get why you need this extra stage field in the key. Why
>> aren't the asid and vmid tags enough?
>>
> Looking again, I think we can do it with ASID and VMID only, but that
> requires some rework in the invalidation path.
>
> With nested SMMUs, we can cache entries from:
> - Stage-1 (or nested): Tagged with VMID and ASID
> - Stage-2: Tagged with VMID only (ASID = -1)
>
> That should be enough for caching/lookup, but for invalidation, we
> should be able to invalidate IPAs which are cached from stage-2.
>
> At the moment, we represent ASIDs with < 0 as a wildcard for
> invalidation or stage-2 and they were mutually exclusive.
>
> An example is:
> - CMD_TLBI_NH_VAA: Invalidate stage-1 for a VMID, all ASIDs (we use ASID = -1)
> - CMD_TLBI_NH_VA: Invalidate stage-1 for a VMID, an ASID  ( > 0)
> - CMD_TLBI_S2_IPA: Invalidate stage-2 for a VMID (we use ASID = -1)
>
> We need to distinguish between case 1) and 3) otherwise we over invalidate.
OK I see your point when passing the asid param to smmu_iotlb_inv_iova()
in smmuv3_range_inval().
Well if you can have separate functions for handling S1 and S2 cases
while keeping the current key that may be interesting. It may be clearer
now we have extended support. This can also help in debugging/tracing.
>
> Similarly, CMD_TLBI_NH_ALL(invalidate all stage-1 by VMID) and
> CMD_TLBI_S12_VMALL(invalidate both stages by VMID).
>
> I guess we can add variants of these functions that operate on ASIDs
> (>= 0) or (< 0) which is basically stage-1 or stage-2.
worth trying indeed.

Thanks

Eric
>
> Another case I can think of which is not implemented in QEMU is
> global entries, where we would like to look up entries for all ASIDs
> (-1), but that’s not a problem for now.
>
> I don’t have a strong opinion, I can try to do it this way.
>
> Thanks,
> Mostafa
>
>> Eric
>>> This briefly described in the user manual (ARM IHI 0070 F.b) in
>>> "16.2.1 Caching combined structures".
>>>
>>> Signed-off-by: Mostafa Saleh <smostafa@google.com>
>>> ---
>>>  hw/arm/smmu-common.c         | 27 +++++++++++++++++----------
>>>  hw/arm/smmu-internal.h       |  2 ++
>>>  hw/arm/smmuv3.c              |  5 +++--
>>>  hw/arm/trace-events          |  3 ++-
>>>  include/hw/arm/smmu-common.h |  8 ++++++--
>>>  5 files changed, 30 insertions(+), 15 deletions(-)
>>>
>>> diff --git a/hw/arm/smmu-common.c b/hw/arm/smmu-common.c
>>> index 20630eb670..677dcf9a13 100644
>>> --- a/hw/arm/smmu-common.c
>>> +++ b/hw/arm/smmu-common.c
>>> @@ -38,7 +38,7 @@ static guint smmu_iotlb_key_hash(gconstpointer v)
>>>  
>>>      /* Jenkins hash */
>>>      a = b = c = JHASH_INITVAL + sizeof(*key);
>>> -    a += key->asid + key->vmid + key->level + key->tg;
>>> +    a += key->asid + key->vmid + key->level + key->tg + key->stage;
>>>      b += extract64(key->iova, 0, 32);
>>>      c += extract64(key->iova, 32, 32);
>>>  
>>> @@ -54,14 +54,14 @@ static gboolean smmu_iotlb_key_equal(gconstpointer v1, gconstpointer v2)
>>>  
>>>      return (k1->asid == k2->asid) && (k1->iova == k2->iova) &&
>>>             (k1->level == k2->level) && (k1->tg == k2->tg) &&
>>> -           (k1->vmid == k2->vmid);
>>> +           (k1->vmid == k2->vmid) && (k1->stage == k2->stage);
>>>  }
>>>  
>>>  SMMUIOTLBKey smmu_get_iotlb_key(uint16_t asid, uint16_t vmid, uint64_t iova,
>>> -                                uint8_t tg, uint8_t level)
>>> +                                uint8_t tg, uint8_t level, SMMUStage stage)
>>>  {
>>>      SMMUIOTLBKey key = {.asid = asid, .vmid = vmid, .iova = iova,
>>> -                        .tg = tg, .level = level};
>>> +                        .tg = tg, .level = level, .stage = stage};
>>>  
>>>      return key;
>>>  }
>>> @@ -81,7 +81,8 @@ SMMUTLBEntry *smmu_iotlb_lookup(SMMUState *bs, SMMUTransCfg *cfg,
>>>          SMMUIOTLBKey key;
>>>  
>>>          key = smmu_get_iotlb_key(cfg->asid, cfg->s2cfg.vmid,
>>> -                                 iova & ~mask, tg, level);
>>> +                                 iova & ~mask, tg, level,
>>> +                                 SMMU_STAGE_TO_TLB_TAG(cfg->stage));
>>>          entry = g_hash_table_lookup(bs->iotlb, &key);
>>>          if (entry) {
>>>              break;
>>> @@ -109,15 +110,16 @@ void smmu_iotlb_insert(SMMUState *bs, SMMUTransCfg *cfg, SMMUTLBEntry *new)
>>>  {
>>>      SMMUIOTLBKey *key = g_new0(SMMUIOTLBKey, 1);
>>>      uint8_t tg = (new->granule - 10) / 2;
>>> +    SMMUStage stage_tag = SMMU_STAGE_TO_TLB_TAG(cfg->stage);
>>>  
>>>      if (g_hash_table_size(bs->iotlb) >= SMMU_IOTLB_MAX_SIZE) {
>>>          smmu_iotlb_inv_all(bs);
>>>      }
>>>  
>>>      *key = smmu_get_iotlb_key(cfg->asid, cfg->s2cfg.vmid, new->entry.iova,
>>> -                              tg, new->level);
>>> +                              tg, new->level, stage_tag);
>>>      trace_smmu_iotlb_insert(cfg->asid, cfg->s2cfg.vmid, new->entry.iova,
>>> -                            tg, new->level);
>>> +                            tg, new->level, stage_tag);
>>>      g_hash_table_insert(bs->iotlb, key, new);
>>>  }
>>>  
>>> @@ -159,18 +161,22 @@ static gboolean smmu_hash_remove_by_asid_vmid_iova(gpointer key, gpointer value,
>>>      if (info->vmid >= 0 && info->vmid != SMMU_IOTLB_VMID(iotlb_key)) {
>>>          return false;
>>>      }
>>> +    if (!(info->stage & SMMU_IOTLB_STAGE(iotlb_key))) {
>>> +        return false;
>>> +    }
>>>      return ((info->iova & ~entry->addr_mask) == entry->iova) ||
>>>             ((entry->iova & ~info->mask) == info->iova);
>>>  }
>>>  
>>>  void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
>>> -                         uint8_t tg, uint64_t num_pages, uint8_t ttl)
>>> +                         uint8_t tg, uint64_t num_pages, uint8_t ttl,
>>> +                         SMMUStage stage)
>>>  {
>>>      /* if tg is not set we use 4KB range invalidation */
>>>      uint8_t granule = tg ? tg * 2 + 10 : 12;
>>>  
>>>      if (ttl && (num_pages == 1) && (asid >= 0)) {
>>> -        SMMUIOTLBKey key = smmu_get_iotlb_key(asid, vmid, iova, tg, ttl);
>>> +        SMMUIOTLBKey key = smmu_get_iotlb_key(asid, vmid, iova, tg, ttl, stage);
>>>  
>>>          if (g_hash_table_remove(s->iotlb, &key)) {
>>>              return;
>>> @@ -184,6 +190,7 @@ void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
>>>      SMMUIOTLBPageInvInfo info = {
>>>          .asid = asid, .iova = iova,
>>>          .vmid = vmid,
>>> +        .stage = stage,
>>>          .mask = (num_pages * 1 << granule) - 1};
>>>  
>>>      g_hash_table_foreach_remove(s->iotlb,
>>> @@ -597,7 +604,7 @@ SMMUTLBEntry *smmu_translate(SMMUState *bs, SMMUTransCfg *cfg, dma_addr_t addr,
>>>      if (cached_entry) {
>>>          if ((flag & IOMMU_WO) && !(cached_entry->entry.perm & IOMMU_WO)) {
>>>              info->type = SMMU_PTW_ERR_PERMISSION;
>>> -            info->stage = cfg->stage;
>>> +            info->stage = SMMU_STAGE_TO_TLB_TAG(cfg->stage);
>>>              return NULL;
>>>          }
>>>          return cached_entry;
>>> diff --git a/hw/arm/smmu-internal.h b/hw/arm/smmu-internal.h
>>> index 843bebb185..6caa0ddf21 100644
>>> --- a/hw/arm/smmu-internal.h
>>> +++ b/hw/arm/smmu-internal.h
>>> @@ -133,12 +133,14 @@ static inline int pgd_concat_idx(int start_level, int granule_sz,
>>>  
>>>  #define SMMU_IOTLB_ASID(key) ((key).asid)
>>>  #define SMMU_IOTLB_VMID(key) ((key).vmid)
>>> +#define SMMU_IOTLB_STAGE(key) ((key).stage)
>>>  
>>>  typedef struct SMMUIOTLBPageInvInfo {
>>>      int asid;
>>>      int vmid;
>>>      uint64_t iova;
>>>      uint64_t mask;
>>> +    SMMUStage stage;
>>>  } SMMUIOTLBPageInvInfo;
>>>  
>>>  typedef struct SMMUSIDRange {
>>> diff --git a/hw/arm/smmuv3.c b/hw/arm/smmuv3.c
>>> index f081ff0cc4..b27bf297e1 100644
>>> --- a/hw/arm/smmuv3.c
>>> +++ b/hw/arm/smmuv3.c
>>> @@ -1087,7 +1087,7 @@ static void smmuv3_range_inval(SMMUState *s, Cmd *cmd)
>>>      if (!tg) {
>>>          trace_smmuv3_range_inval(vmid, asid, addr, tg, 1, ttl, leaf);
>>>          smmuv3_inv_notifiers_iova(s, asid, vmid, addr, tg, 1);
>>> -        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, 1, ttl);
>>> +        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, 1, ttl, SMMU_NESTED);
>>>          return;
>>>      }
>>>  
>>> @@ -1105,7 +1105,8 @@ static void smmuv3_range_inval(SMMUState *s, Cmd *cmd)
>>>          num_pages = (mask + 1) >> granule;
>>>          trace_smmuv3_range_inval(vmid, asid, addr, tg, num_pages, ttl, leaf);
>>>          smmuv3_inv_notifiers_iova(s, asid, vmid, addr, tg, num_pages);
>>> -        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, num_pages, ttl);
>>> +        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg,
>>> +                            num_pages, ttl, SMMU_NESTED);
>>>          addr += mask + 1;
>>>      }
>>>  }
>>> diff --git a/hw/arm/trace-events b/hw/arm/trace-events
>>> index cc12924a84..3000c3bf14 100644
>>> --- a/hw/arm/trace-events
>>> +++ b/hw/arm/trace-events
>>> @@ -14,10 +14,11 @@ smmu_iotlb_inv_all(void) "IOTLB invalidate all"
>>>  smmu_iotlb_inv_asid(uint16_t asid) "IOTLB invalidate asid=%d"
>>>  smmu_iotlb_inv_vmid(uint16_t vmid) "IOTLB invalidate vmid=%d"
>>>  smmu_iotlb_inv_iova(uint16_t asid, uint64_t addr) "IOTLB invalidate asid=%d addr=0x%"PRIx64
>>> +smmu_iotlb_inv_stage(int stage) "Stage invalidate stage=%d"
>>>  smmu_inv_notifiers_mr(const char *name) "iommu mr=%s"
>>>  smmu_iotlb_lookup_hit(uint16_t asid, uint16_t vmid, uint64_t addr, uint32_t hit, uint32_t miss, uint32_t p) "IOTLB cache HIT asid=%d vmid=%d addr=0x%"PRIx64" hit=%d miss=%d hit rate=%d"
>>>  smmu_iotlb_lookup_miss(uint16_t asid, uint16_t vmid, uint64_t addr, uint32_t hit, uint32_t miss, uint32_t p) "IOTLB cache MISS asid=%d vmid=%d addr=0x%"PRIx64" hit=%d miss=%d hit rate=%d"
>>> -smmu_iotlb_insert(uint16_t asid, uint16_t vmid, uint64_t addr, uint8_t tg, uint8_t level) "IOTLB ++ asid=%d vmid=%d addr=0x%"PRIx64" tg=%d level=%d"
>>> +smmu_iotlb_insert(uint16_t asid, uint16_t vmid, uint64_t addr, uint8_t tg, uint8_t level, int stage) "IOTLB ++ asid=%d vmid=%d addr=0x%"PRIx64" tg=%d level=%d stage=%d"
>>>  
>>>  # smmuv3.c
>>>  smmuv3_read_mmio(uint64_t addr, uint64_t val, unsigned size, uint32_t r) "addr: 0x%"PRIx64" val:0x%"PRIx64" size: 0x%x(%d)"
>>> diff --git a/include/hw/arm/smmu-common.h b/include/hw/arm/smmu-common.h
>>> index 876e78975c..695d6d10ad 100644
>>> --- a/include/hw/arm/smmu-common.h
>>> +++ b/include/hw/arm/smmu-common.h
>>> @@ -37,6 +37,8 @@
>>>  #define VMSA_IDXMSK(isz, strd, lvl)         ((1ULL << \
>>>                                               VMSA_BIT_LVL(isz, strd, lvl)) - 1)
>>>  
>>> +#define SMMU_STAGE_TO_TLB_TAG(stage)        (((stage) == SMMU_NESTED) ? \
>>> +                                             SMMU_STAGE_1 : (stage))
>>>  /*
>>>   * Page table walk error types
>>>   */
>>> @@ -136,6 +138,7 @@ typedef struct SMMUIOTLBKey {
>>>      uint16_t vmid;
>>>      uint8_t tg;
>>>      uint8_t level;
>>> +    SMMUStage stage;
>>>  } SMMUIOTLBKey;
>>>  
>>>  struct SMMUState {
>>> @@ -203,12 +206,13 @@ SMMUTLBEntry *smmu_iotlb_lookup(SMMUState *bs, SMMUTransCfg *cfg,
>>>                                  SMMUTransTableInfo *tt, hwaddr iova);
>>>  void smmu_iotlb_insert(SMMUState *bs, SMMUTransCfg *cfg, SMMUTLBEntry *entry);
>>>  SMMUIOTLBKey smmu_get_iotlb_key(uint16_t asid, uint16_t vmid, uint64_t iova,
>>> -                                uint8_t tg, uint8_t level);
>>> +                                uint8_t tg, uint8_t level, SMMUStage stage);
>>>  void smmu_iotlb_inv_all(SMMUState *s);
>>>  void smmu_iotlb_inv_asid(SMMUState *s, uint16_t asid);
>>>  void smmu_iotlb_inv_vmid(SMMUState *s, uint16_t vmid);
>>>  void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
>>> -                         uint8_t tg, uint64_t num_pages, uint8_t ttl);
>>> +                         uint8_t tg, uint64_t num_pages, uint8_t ttl,
>>> +                         SMMUStage stage);
>>>  
>>>  /* Unmap the range of all the notifiers registered to any IOMMU mr */
>>>  void smmu_inv_notifiers_all(SMMUState *s);
Mostafa Saleh April 3, 2024, 9:55 a.m. UTC | #4
On Wed, Apr 03, 2024 at 09:22:03AM +0200, Eric Auger wrote:
> Hi Mostafa,
> 
> On 4/2/24 20:47, Mostafa Saleh wrote:
> > Hi Eric,
> >
> > On Tue, Apr 02, 2024 at 07:15:20PM +0200, Eric Auger wrote:
> >> Hi Mostafa,
> >>
> >> On 3/25/24 11:13, Mostafa Saleh wrote:
> >>> TLBs for nesting will be extended to be combined, a new index is added
> >>> "stage", with 2 valid values:
> >>>  - SMMU_STAGE_1: Meaning this translates VA to PADDR, this entry can
> >>>    be cached from fully nested configuration or from stage-1 only.
> >>>    We don't support separate cached entries (VA to IPA)
> >>>
> >>>  - SMMU_STAGE_2: Meaning this translates IPA to PADDR, cached from
> >>>    stage-2 only configuration.
> >>>
> >>> For TLB invalidation:
> >>>  - by VA: Invalidate TLBs tagged with SMMU_STAGE_1
> >>>  - by IPA: Invalidate TLBs tagged with SMMU_STAGE_2
> >>>  - All: Will invalidate both, this is communicated to the TLB as
> >>>    SMMU_NESTED which is (SMMU_STAGE_1 | SMMU_STAGE_2) which uses
> >>>    it as a mask.
> >> I don't really get why you need this extra stage field in the key. Why
> >> aren't the asid and vmid tags enough?
> >>
> > Looking again, I think we can do it with ASID and VMID only, but that
> > requires some rework in the invalidation path.
> >
> > With nested SMMUs, we can cache entries from:
> > - Stage-1 (or nested): Tagged with VMID and ASID
> > - Stage-2: Tagged with VMID only (ASID = -1)
> >
> > That should be enough for caching/lookup, but for invalidation, we
> > should be able to invalidate IPAs which are cached from stage-2.
> >
> > At the moment, we represent ASIDs with < 0 as a wildcard for
> > invalidation or stage-2 and they were mutually exclusive.
> >
> > An example is:
> > - CMD_TLBI_NH_VAA: Invalidate stage-1 for a VMID, all ASIDs (we use ASID = -1)
> > - CMD_TLBI_NH_VA: Invalidate stage-1 for a VMID, an ASID  ( > 0)
> > - CMD_TLBI_S2_IPA: Invalidate stage-2 for a VMID (we use ASID = -1)
> >
> > We need to distinguish between cases 1) and 3); otherwise we over-invalidate.
> OK I see your point when passing the asid param to smmu_iotlb_inv_iova()
> in smmuv3_range_inval().
> Well if you can have separate functions for handling S1 and S2 cases
> while keeping the current key that may be interesting. It may be clearer
> now we have extended support. This can also help in debugging/tracing.
> >
> > Similarly, CMD_TLBI_NH_ALL(invalidate all stage-1 by VMID) and
> > CMD_TLBI_S12_VMALL(invalidate both stages by VMID).
> >
> > I guess we can add variants of these functions that operate on ASIDs
> > (>= 0) or (< 0) which is basically stage-1 or stage-2.
> worth trying, indeed.

I will switch to that in V2.

Thanks,
Mostafa

> 
> Thanks
> 
> Eric
> >
> > Another case I can think of which is not implemented in QEMU is
> > global entries, where we would like to look up entries for all ASIDs
> > (-1), but that’s not a problem for now.
> >
> > I don’t have a strong opinion, I can try to do it this way.
> >
> > Thanks,
> > Mostafa
> >
> >> Eric
> >>> This is briefly described in the user manual (ARM IHI 0070 F.b) in
> >>> "16.2.1 Caching combined structures".
> >>>
> >>> Signed-off-by: Mostafa Saleh <smostafa@google.com>
> >>> ---
> >>>  hw/arm/smmu-common.c         | 27 +++++++++++++++++----------
> >>>  hw/arm/smmu-internal.h       |  2 ++
> >>>  hw/arm/smmuv3.c              |  5 +++--
> >>>  hw/arm/trace-events          |  3 ++-
> >>>  include/hw/arm/smmu-common.h |  8 ++++++--
> >>>  5 files changed, 30 insertions(+), 15 deletions(-)
> >>>
> >>> diff --git a/hw/arm/smmu-common.c b/hw/arm/smmu-common.c
> >>> index 20630eb670..677dcf9a13 100644
> >>> --- a/hw/arm/smmu-common.c
> >>> +++ b/hw/arm/smmu-common.c
> >>> @@ -38,7 +38,7 @@ static guint smmu_iotlb_key_hash(gconstpointer v)
> >>>  
> >>>      /* Jenkins hash */
> >>>      a = b = c = JHASH_INITVAL + sizeof(*key);
> >>> -    a += key->asid + key->vmid + key->level + key->tg;
> >>> +    a += key->asid + key->vmid + key->level + key->tg + key->stage;
> >>>      b += extract64(key->iova, 0, 32);
> >>>      c += extract64(key->iova, 32, 32);
> >>>  
> >>> @@ -54,14 +54,14 @@ static gboolean smmu_iotlb_key_equal(gconstpointer v1, gconstpointer v2)
> >>>  
> >>>      return (k1->asid == k2->asid) && (k1->iova == k2->iova) &&
> >>>             (k1->level == k2->level) && (k1->tg == k2->tg) &&
> >>> -           (k1->vmid == k2->vmid);
> >>> +           (k1->vmid == k2->vmid) && (k1->stage == k2->stage);
> >>>  }
> >>>  
> >>>  SMMUIOTLBKey smmu_get_iotlb_key(uint16_t asid, uint16_t vmid, uint64_t iova,
> >>> -                                uint8_t tg, uint8_t level)
> >>> +                                uint8_t tg, uint8_t level, SMMUStage stage)
> >>>  {
> >>>      SMMUIOTLBKey key = {.asid = asid, .vmid = vmid, .iova = iova,
> >>> -                        .tg = tg, .level = level};
> >>> +                        .tg = tg, .level = level, .stage = stage};
> >>>  
> >>>      return key;
> >>>  }
> >>> @@ -81,7 +81,8 @@ SMMUTLBEntry *smmu_iotlb_lookup(SMMUState *bs, SMMUTransCfg *cfg,
> >>>          SMMUIOTLBKey key;
> >>>  
> >>>          key = smmu_get_iotlb_key(cfg->asid, cfg->s2cfg.vmid,
> >>> -                                 iova & ~mask, tg, level);
> >>> +                                 iova & ~mask, tg, level,
> >>> +                                 SMMU_STAGE_TO_TLB_TAG(cfg->stage));
> >>>          entry = g_hash_table_lookup(bs->iotlb, &key);
> >>>          if (entry) {
> >>>              break;
> >>> @@ -109,15 +110,16 @@ void smmu_iotlb_insert(SMMUState *bs, SMMUTransCfg *cfg, SMMUTLBEntry *new)
> >>>  {
> >>>      SMMUIOTLBKey *key = g_new0(SMMUIOTLBKey, 1);
> >>>      uint8_t tg = (new->granule - 10) / 2;
> >>> +    SMMUStage stage_tag = SMMU_STAGE_TO_TLB_TAG(cfg->stage);
> >>>  
> >>>      if (g_hash_table_size(bs->iotlb) >= SMMU_IOTLB_MAX_SIZE) {
> >>>          smmu_iotlb_inv_all(bs);
> >>>      }
> >>>  
> >>>      *key = smmu_get_iotlb_key(cfg->asid, cfg->s2cfg.vmid, new->entry.iova,
> >>> -                              tg, new->level);
> >>> +                              tg, new->level, stage_tag);
> >>>      trace_smmu_iotlb_insert(cfg->asid, cfg->s2cfg.vmid, new->entry.iova,
> >>> -                            tg, new->level);
> >>> +                            tg, new->level, stage_tag);
> >>>      g_hash_table_insert(bs->iotlb, key, new);
> >>>  }
> >>>  
> >>> @@ -159,18 +161,22 @@ static gboolean smmu_hash_remove_by_asid_vmid_iova(gpointer key, gpointer value,
> >>>      if (info->vmid >= 0 && info->vmid != SMMU_IOTLB_VMID(iotlb_key)) {
> >>>          return false;
> >>>      }
> >>> +    if (!(info->stage & SMMU_IOTLB_STAGE(iotlb_key))) {
> >>> +        return false;
> >>> +    }
> >>>      return ((info->iova & ~entry->addr_mask) == entry->iova) ||
> >>>             ((entry->iova & ~info->mask) == info->iova);
> >>>  }
> >>>  
> >>>  void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
> >>> -                         uint8_t tg, uint64_t num_pages, uint8_t ttl)
> >>> +                         uint8_t tg, uint64_t num_pages, uint8_t ttl,
> >>> +                         SMMUStage stage)
> >>>  {
> >>>      /* if tg is not set we use 4KB range invalidation */
> >>>      uint8_t granule = tg ? tg * 2 + 10 : 12;
> >>>  
> >>>      if (ttl && (num_pages == 1) && (asid >= 0)) {
> >>> -        SMMUIOTLBKey key = smmu_get_iotlb_key(asid, vmid, iova, tg, ttl);
> >>> +        SMMUIOTLBKey key = smmu_get_iotlb_key(asid, vmid, iova, tg, ttl, stage);
> >>>  
> >>>          if (g_hash_table_remove(s->iotlb, &key)) {
> >>>              return;
> >>> @@ -184,6 +190,7 @@ void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
> >>>      SMMUIOTLBPageInvInfo info = {
> >>>          .asid = asid, .iova = iova,
> >>>          .vmid = vmid,
> >>> +        .stage = stage,
> >>>          .mask = (num_pages * 1 << granule) - 1};
> >>>  
> >>>      g_hash_table_foreach_remove(s->iotlb,
> >>> @@ -597,7 +604,7 @@ SMMUTLBEntry *smmu_translate(SMMUState *bs, SMMUTransCfg *cfg, dma_addr_t addr,
> >>>      if (cached_entry) {
> >>>          if ((flag & IOMMU_WO) && !(cached_entry->entry.perm & IOMMU_WO)) {
> >>>              info->type = SMMU_PTW_ERR_PERMISSION;
> >>> -            info->stage = cfg->stage;
> >>> +            info->stage = SMMU_STAGE_TO_TLB_TAG(cfg->stage);
> >>>              return NULL;
> >>>          }
> >>>          return cached_entry;
> >>> diff --git a/hw/arm/smmu-internal.h b/hw/arm/smmu-internal.h
> >>> index 843bebb185..6caa0ddf21 100644
> >>> --- a/hw/arm/smmu-internal.h
> >>> +++ b/hw/arm/smmu-internal.h
> >>> @@ -133,12 +133,14 @@ static inline int pgd_concat_idx(int start_level, int granule_sz,
> >>>  
> >>>  #define SMMU_IOTLB_ASID(key) ((key).asid)
> >>>  #define SMMU_IOTLB_VMID(key) ((key).vmid)
> >>> +#define SMMU_IOTLB_STAGE(key) ((key).stage)
> >>>  
> >>>  typedef struct SMMUIOTLBPageInvInfo {
> >>>      int asid;
> >>>      int vmid;
> >>>      uint64_t iova;
> >>>      uint64_t mask;
> >>> +    SMMUStage stage;
> >>>  } SMMUIOTLBPageInvInfo;
> >>>  
> >>>  typedef struct SMMUSIDRange {
> >>> diff --git a/hw/arm/smmuv3.c b/hw/arm/smmuv3.c
> >>> index f081ff0cc4..b27bf297e1 100644
> >>> --- a/hw/arm/smmuv3.c
> >>> +++ b/hw/arm/smmuv3.c
> >>> @@ -1087,7 +1087,7 @@ static void smmuv3_range_inval(SMMUState *s, Cmd *cmd)
> >>>      if (!tg) {
> >>>          trace_smmuv3_range_inval(vmid, asid, addr, tg, 1, ttl, leaf);
> >>>          smmuv3_inv_notifiers_iova(s, asid, vmid, addr, tg, 1);
> >>> -        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, 1, ttl);
> >>> +        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, 1, ttl, SMMU_NESTED);
> >>>          return;
> >>>      }
> >>>  
> >>> @@ -1105,7 +1105,8 @@ static void smmuv3_range_inval(SMMUState *s, Cmd *cmd)
> >>>          num_pages = (mask + 1) >> granule;
> >>>          trace_smmuv3_range_inval(vmid, asid, addr, tg, num_pages, ttl, leaf);
> >>>          smmuv3_inv_notifiers_iova(s, asid, vmid, addr, tg, num_pages);
> >>> -        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, num_pages, ttl);
> >>> +        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg,
> >>> +                            num_pages, ttl, SMMU_NESTED);
> >>>          addr += mask + 1;
> >>>      }
> >>>  }
> >>> diff --git a/hw/arm/trace-events b/hw/arm/trace-events
> >>> index cc12924a84..3000c3bf14 100644
> >>> --- a/hw/arm/trace-events
> >>> +++ b/hw/arm/trace-events
> >>> @@ -14,10 +14,11 @@ smmu_iotlb_inv_all(void) "IOTLB invalidate all"
> >>>  smmu_iotlb_inv_asid(uint16_t asid) "IOTLB invalidate asid=%d"
> >>>  smmu_iotlb_inv_vmid(uint16_t vmid) "IOTLB invalidate vmid=%d"
> >>>  smmu_iotlb_inv_iova(uint16_t asid, uint64_t addr) "IOTLB invalidate asid=%d addr=0x%"PRIx64
> >>> +smmu_iotlb_inv_stage(int stage) "Stage invalidate stage=%d"
> >>>  smmu_inv_notifiers_mr(const char *name) "iommu mr=%s"
> >>>  smmu_iotlb_lookup_hit(uint16_t asid, uint16_t vmid, uint64_t addr, uint32_t hit, uint32_t miss, uint32_t p) "IOTLB cache HIT asid=%d vmid=%d addr=0x%"PRIx64" hit=%d miss=%d hit rate=%d"
> >>>  smmu_iotlb_lookup_miss(uint16_t asid, uint16_t vmid, uint64_t addr, uint32_t hit, uint32_t miss, uint32_t p) "IOTLB cache MISS asid=%d vmid=%d addr=0x%"PRIx64" hit=%d miss=%d hit rate=%d"
> >>> -smmu_iotlb_insert(uint16_t asid, uint16_t vmid, uint64_t addr, uint8_t tg, uint8_t level) "IOTLB ++ asid=%d vmid=%d addr=0x%"PRIx64" tg=%d level=%d"
> >>> +smmu_iotlb_insert(uint16_t asid, uint16_t vmid, uint64_t addr, uint8_t tg, uint8_t level, int stage) "IOTLB ++ asid=%d vmid=%d addr=0x%"PRIx64" tg=%d level=%d stage=%d"
> >>>  
> >>>  # smmuv3.c
> >>>  smmuv3_read_mmio(uint64_t addr, uint64_t val, unsigned size, uint32_t r) "addr: 0x%"PRIx64" val:0x%"PRIx64" size: 0x%x(%d)"
> >>> diff --git a/include/hw/arm/smmu-common.h b/include/hw/arm/smmu-common.h
> >>> index 876e78975c..695d6d10ad 100644
> >>> --- a/include/hw/arm/smmu-common.h
> >>> +++ b/include/hw/arm/smmu-common.h
> >>> @@ -37,6 +37,8 @@
> >>>  #define VMSA_IDXMSK(isz, strd, lvl)         ((1ULL << \
> >>>                                               VMSA_BIT_LVL(isz, strd, lvl)) - 1)
> >>>  
> >>> +#define SMMU_STAGE_TO_TLB_TAG(stage)        (((stage) == SMMU_NESTED) ? \
> >>> +                                             SMMU_STAGE_1 : (stage))
> >>>  /*
> >>>   * Page table walk error types
> >>>   */
> >>> @@ -136,6 +138,7 @@ typedef struct SMMUIOTLBKey {
> >>>      uint16_t vmid;
> >>>      uint8_t tg;
> >>>      uint8_t level;
> >>> +    SMMUStage stage;
> >>>  } SMMUIOTLBKey;
> >>>  
> >>>  struct SMMUState {
> >>> @@ -203,12 +206,13 @@ SMMUTLBEntry *smmu_iotlb_lookup(SMMUState *bs, SMMUTransCfg *cfg,
> >>>                                  SMMUTransTableInfo *tt, hwaddr iova);
> >>>  void smmu_iotlb_insert(SMMUState *bs, SMMUTransCfg *cfg, SMMUTLBEntry *entry);
> >>>  SMMUIOTLBKey smmu_get_iotlb_key(uint16_t asid, uint16_t vmid, uint64_t iova,
> >>> -                                uint8_t tg, uint8_t level);
> >>> +                                uint8_t tg, uint8_t level, SMMUStage stage);
> >>>  void smmu_iotlb_inv_all(SMMUState *s);
> >>>  void smmu_iotlb_inv_asid(SMMUState *s, uint16_t asid);
> >>>  void smmu_iotlb_inv_vmid(SMMUState *s, uint16_t vmid);
> >>>  void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
> >>> -                         uint8_t tg, uint64_t num_pages, uint8_t ttl);
> >>> +                         uint8_t tg, uint64_t num_pages, uint8_t ttl,
> >>> +                         SMMUStage stage);
> >>>  
> >>>  /* Unmap the range of all the notifiers registered to any IOMMU mr */
> >>>  void smmu_inv_notifiers_all(SMMUState *s);
>
diff mbox series

Patch

diff --git a/hw/arm/smmu-common.c b/hw/arm/smmu-common.c
index 20630eb670..677dcf9a13 100644
--- a/hw/arm/smmu-common.c
+++ b/hw/arm/smmu-common.c
@@ -38,7 +38,7 @@  static guint smmu_iotlb_key_hash(gconstpointer v)
 
     /* Jenkins hash */
     a = b = c = JHASH_INITVAL + sizeof(*key);
-    a += key->asid + key->vmid + key->level + key->tg;
+    a += key->asid + key->vmid + key->level + key->tg + key->stage;
     b += extract64(key->iova, 0, 32);
     c += extract64(key->iova, 32, 32);
 
@@ -54,14 +54,14 @@  static gboolean smmu_iotlb_key_equal(gconstpointer v1, gconstpointer v2)
 
     return (k1->asid == k2->asid) && (k1->iova == k2->iova) &&
            (k1->level == k2->level) && (k1->tg == k2->tg) &&
-           (k1->vmid == k2->vmid);
+           (k1->vmid == k2->vmid) && (k1->stage == k2->stage);
 }
 
 SMMUIOTLBKey smmu_get_iotlb_key(uint16_t asid, uint16_t vmid, uint64_t iova,
-                                uint8_t tg, uint8_t level)
+                                uint8_t tg, uint8_t level, SMMUStage stage)
 {
     SMMUIOTLBKey key = {.asid = asid, .vmid = vmid, .iova = iova,
-                        .tg = tg, .level = level};
+                        .tg = tg, .level = level, .stage = stage};
 
     return key;
 }
@@ -81,7 +81,8 @@  SMMUTLBEntry *smmu_iotlb_lookup(SMMUState *bs, SMMUTransCfg *cfg,
         SMMUIOTLBKey key;
 
         key = smmu_get_iotlb_key(cfg->asid, cfg->s2cfg.vmid,
-                                 iova & ~mask, tg, level);
+                                 iova & ~mask, tg, level,
+                                 SMMU_STAGE_TO_TLB_TAG(cfg->stage));
         entry = g_hash_table_lookup(bs->iotlb, &key);
         if (entry) {
             break;
@@ -109,15 +110,16 @@  void smmu_iotlb_insert(SMMUState *bs, SMMUTransCfg *cfg, SMMUTLBEntry *new)
 {
     SMMUIOTLBKey *key = g_new0(SMMUIOTLBKey, 1);
     uint8_t tg = (new->granule - 10) / 2;
+    SMMUStage stage_tag = SMMU_STAGE_TO_TLB_TAG(cfg->stage);
 
     if (g_hash_table_size(bs->iotlb) >= SMMU_IOTLB_MAX_SIZE) {
         smmu_iotlb_inv_all(bs);
     }
 
     *key = smmu_get_iotlb_key(cfg->asid, cfg->s2cfg.vmid, new->entry.iova,
-                              tg, new->level);
+                              tg, new->level, stage_tag);
     trace_smmu_iotlb_insert(cfg->asid, cfg->s2cfg.vmid, new->entry.iova,
-                            tg, new->level);
+                            tg, new->level, stage_tag);
     g_hash_table_insert(bs->iotlb, key, new);
 }
 
@@ -159,18 +161,22 @@  static gboolean smmu_hash_remove_by_asid_vmid_iova(gpointer key, gpointer value,
     if (info->vmid >= 0 && info->vmid != SMMU_IOTLB_VMID(iotlb_key)) {
         return false;
     }
+    if (!(info->stage & SMMU_IOTLB_STAGE(iotlb_key))) {
+        return false;
+    }
     return ((info->iova & ~entry->addr_mask) == entry->iova) ||
            ((entry->iova & ~info->mask) == info->iova);
 }
 
 void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
-                         uint8_t tg, uint64_t num_pages, uint8_t ttl)
+                         uint8_t tg, uint64_t num_pages, uint8_t ttl,
+                         SMMUStage stage)
 {
     /* if tg is not set we use 4KB range invalidation */
     uint8_t granule = tg ? tg * 2 + 10 : 12;
 
     if (ttl && (num_pages == 1) && (asid >= 0)) {
-        SMMUIOTLBKey key = smmu_get_iotlb_key(asid, vmid, iova, tg, ttl);
+        SMMUIOTLBKey key = smmu_get_iotlb_key(asid, vmid, iova, tg, ttl, stage);
 
         if (g_hash_table_remove(s->iotlb, &key)) {
             return;
@@ -184,6 +190,7 @@  void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
     SMMUIOTLBPageInvInfo info = {
         .asid = asid, .iova = iova,
         .vmid = vmid,
+        .stage = stage,
         .mask = (num_pages * 1 << granule) - 1};
 
     g_hash_table_foreach_remove(s->iotlb,
@@ -597,7 +604,7 @@  SMMUTLBEntry *smmu_translate(SMMUState *bs, SMMUTransCfg *cfg, dma_addr_t addr,
     if (cached_entry) {
         if ((flag & IOMMU_WO) && !(cached_entry->entry.perm & IOMMU_WO)) {
             info->type = SMMU_PTW_ERR_PERMISSION;
-            info->stage = cfg->stage;
+            info->stage = SMMU_STAGE_TO_TLB_TAG(cfg->stage);
             return NULL;
         }
         return cached_entry;
diff --git a/hw/arm/smmu-internal.h b/hw/arm/smmu-internal.h
index 843bebb185..6caa0ddf21 100644
--- a/hw/arm/smmu-internal.h
+++ b/hw/arm/smmu-internal.h
@@ -133,12 +133,14 @@  static inline int pgd_concat_idx(int start_level, int granule_sz,
 
 #define SMMU_IOTLB_ASID(key) ((key).asid)
 #define SMMU_IOTLB_VMID(key) ((key).vmid)
+#define SMMU_IOTLB_STAGE(key) ((key).stage)
 
 typedef struct SMMUIOTLBPageInvInfo {
     int asid;
     int vmid;
     uint64_t iova;
     uint64_t mask;
+    SMMUStage stage;
 } SMMUIOTLBPageInvInfo;
 
 typedef struct SMMUSIDRange {
diff --git a/hw/arm/smmuv3.c b/hw/arm/smmuv3.c
index f081ff0cc4..b27bf297e1 100644
--- a/hw/arm/smmuv3.c
+++ b/hw/arm/smmuv3.c
@@ -1087,7 +1087,7 @@  static void smmuv3_range_inval(SMMUState *s, Cmd *cmd)
     if (!tg) {
         trace_smmuv3_range_inval(vmid, asid, addr, tg, 1, ttl, leaf);
         smmuv3_inv_notifiers_iova(s, asid, vmid, addr, tg, 1);
-        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, 1, ttl);
+        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, 1, ttl, SMMU_NESTED);
         return;
     }
 
@@ -1105,7 +1105,8 @@  static void smmuv3_range_inval(SMMUState *s, Cmd *cmd)
         num_pages = (mask + 1) >> granule;
         trace_smmuv3_range_inval(vmid, asid, addr, tg, num_pages, ttl, leaf);
         smmuv3_inv_notifiers_iova(s, asid, vmid, addr, tg, num_pages);
-        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg, num_pages, ttl);
+        smmu_iotlb_inv_iova(s, asid, vmid, addr, tg,
+                            num_pages, ttl, SMMU_NESTED);
         addr += mask + 1;
     }
 }
diff --git a/hw/arm/trace-events b/hw/arm/trace-events
index cc12924a84..3000c3bf14 100644
--- a/hw/arm/trace-events
+++ b/hw/arm/trace-events
@@ -14,10 +14,11 @@  smmu_iotlb_inv_all(void) "IOTLB invalidate all"
 smmu_iotlb_inv_asid(uint16_t asid) "IOTLB invalidate asid=%d"
 smmu_iotlb_inv_vmid(uint16_t vmid) "IOTLB invalidate vmid=%d"
 smmu_iotlb_inv_iova(uint16_t asid, uint64_t addr) "IOTLB invalidate asid=%d addr=0x%"PRIx64
+smmu_iotlb_inv_stage(int stage) "Stage invalidate stage=%d"
 smmu_inv_notifiers_mr(const char *name) "iommu mr=%s"
 smmu_iotlb_lookup_hit(uint16_t asid, uint16_t vmid, uint64_t addr, uint32_t hit, uint32_t miss, uint32_t p) "IOTLB cache HIT asid=%d vmid=%d addr=0x%"PRIx64" hit=%d miss=%d hit rate=%d"
 smmu_iotlb_lookup_miss(uint16_t asid, uint16_t vmid, uint64_t addr, uint32_t hit, uint32_t miss, uint32_t p) "IOTLB cache MISS asid=%d vmid=%d addr=0x%"PRIx64" hit=%d miss=%d hit rate=%d"
-smmu_iotlb_insert(uint16_t asid, uint16_t vmid, uint64_t addr, uint8_t tg, uint8_t level) "IOTLB ++ asid=%d vmid=%d addr=0x%"PRIx64" tg=%d level=%d"
+smmu_iotlb_insert(uint16_t asid, uint16_t vmid, uint64_t addr, uint8_t tg, uint8_t level, int stage) "IOTLB ++ asid=%d vmid=%d addr=0x%"PRIx64" tg=%d level=%d stage=%d"
 
 # smmuv3.c
 smmuv3_read_mmio(uint64_t addr, uint64_t val, unsigned size, uint32_t r) "addr: 0x%"PRIx64" val:0x%"PRIx64" size: 0x%x(%d)"
diff --git a/include/hw/arm/smmu-common.h b/include/hw/arm/smmu-common.h
index 876e78975c..695d6d10ad 100644
--- a/include/hw/arm/smmu-common.h
+++ b/include/hw/arm/smmu-common.h
@@ -37,6 +37,8 @@ 
 #define VMSA_IDXMSK(isz, strd, lvl)         ((1ULL << \
                                              VMSA_BIT_LVL(isz, strd, lvl)) - 1)
 
+#define SMMU_STAGE_TO_TLB_TAG(stage)        (((stage) == SMMU_NESTED) ? \
+                                             SMMU_STAGE_1 : (stage))
 /*
  * Page table walk error types
  */
@@ -136,6 +138,7 @@  typedef struct SMMUIOTLBKey {
     uint16_t vmid;
     uint8_t tg;
     uint8_t level;
+    SMMUStage stage;
 } SMMUIOTLBKey;
 
 struct SMMUState {
@@ -203,12 +206,13 @@  SMMUTLBEntry *smmu_iotlb_lookup(SMMUState *bs, SMMUTransCfg *cfg,
                                 SMMUTransTableInfo *tt, hwaddr iova);
 void smmu_iotlb_insert(SMMUState *bs, SMMUTransCfg *cfg, SMMUTLBEntry *entry);
 SMMUIOTLBKey smmu_get_iotlb_key(uint16_t asid, uint16_t vmid, uint64_t iova,
-                                uint8_t tg, uint8_t level);
+                                uint8_t tg, uint8_t level, SMMUStage stage);
 void smmu_iotlb_inv_all(SMMUState *s);
 void smmu_iotlb_inv_asid(SMMUState *s, uint16_t asid);
 void smmu_iotlb_inv_vmid(SMMUState *s, uint16_t vmid);
 void smmu_iotlb_inv_iova(SMMUState *s, int asid, int vmid, dma_addr_t iova,
-                         uint8_t tg, uint64_t num_pages, uint8_t ttl);
+                         uint8_t tg, uint64_t num_pages, uint8_t ttl,
+                         SMMUStage stage);
 
 /* Unmap the range of all the notifiers registered to any IOMMU mr */
 void smmu_inv_notifiers_all(SMMUState *s);