diff mbox series

[v3,08/13] hw/riscv/riscv-iommu: add Address Translation Cache (IOATC)

Message ID 20240523173955.1940072-9-dbarboza@ventanamicro.com
State New
Headers show
Series riscv: QEMU RISC-V IOMMU Support | expand

Commit Message

Daniel Henrique Barboza May 23, 2024, 5:39 p.m. UTC
From: Tomasz Jeznach <tjeznach@rivosinc.com>

The RISC-V IOMMU spec predicts that the IOMMU can use translation caches
to hold entries from the DDT. This includes implementation for all cache
commands that are marked as 'not implemented'.

There are some artifacts included in the cache that predicts s-stage and
g-stage elements, although we don't support it yet. We'll introduce them
next.

Signed-off-by: Tomasz Jeznach <tjeznach@rivosinc.com>
Signed-off-by: Daniel Henrique Barboza <dbarboza@ventanamicro.com>
Reviewed-by: Frank Chang <frank.chang@sifive.com>
---
 hw/riscv/riscv-iommu.c | 189 ++++++++++++++++++++++++++++++++++++++++-
 hw/riscv/riscv-iommu.h |   2 +
 2 files changed, 187 insertions(+), 4 deletions(-)

Comments

Tomasz Jeznach June 5, 2024, 5:34 p.m. UTC | #1
Daniel,

Thank you for your upstreaming work!

I've synchronized the private branch with v3 changes, and noticed
there is an important change missing in this patchset. We need
reader-writer lock around access to GLib.HashTable as it's not
MT-safe.  Diff added below, also available on github [1] branch
riscv_iommu_v4-rc1.

[1] link: https://github.com/tjeznach/qemu/tree/riscv_iommu_v4-rc1

Thanks!
- Tomasz Jeznach

diff --git a/hw/riscv/riscv-iommu.c b/hw/riscv/riscv-iommu.c
index a27f56419a..75c5d645fc 100644
--- a/hw/riscv/riscv-iommu.c
+++ b/hw/riscv/riscv-iommu.c
@@ -991,7 +991,9 @@ static void riscv_iommu_ctx_inval(RISCVIOMMUState
*s, GHFunc func,
         .pasid = pasid,
     };
     ctx_cache = g_hash_table_ref(s->ctx_cache);
+    pthread_rwlock_wrlock(&s->ctx_lock);
     g_hash_table_foreach(ctx_cache, func, &key);
+    pthread_rwlock_unlock(&s->ctx_lock);
     g_hash_table_unref(ctx_cache);
 }

@@ -1007,26 +1009,31 @@ static RISCVIOMMUContext
*riscv_iommu_ctx(RISCVIOMMUState *s,
     };

     ctx_cache = g_hash_table_ref(s->ctx_cache);
+    pthread_rwlock_rdlock(&s->ctx_lock);
     ctx = g_hash_table_lookup(ctx_cache, &key);
+    pthread_rwlock_unlock(&s->ctx_lock);

     if (ctx && (ctx->tc & RISCV_IOMMU_DC_TC_V)) {
         *ref = ctx_cache;
         return ctx;
     }

-    if (g_hash_table_size(s->ctx_cache) >= LIMIT_CACHE_CTX) {
-        ctx_cache = g_hash_table_new_full(__ctx_hash, __ctx_equal,
-                                          g_free, NULL);
-        g_hash_table_unref(qatomic_xchg(&s->ctx_cache, ctx_cache));
-    }
-
     ctx = g_new0(RISCVIOMMUContext, 1);
     ctx->devid = devid;
     ctx->pasid = pasid;

     int fault = riscv_iommu_ctx_fetch(s, ctx);
     if (!fault) {
+        pthread_rwlock_wrlock(&s->ctx_lock);
+        if (g_hash_table_size(ctx_cache) >= LIMIT_CACHE_CTX) {
+            g_hash_table_unref(ctx_cache);
+            ctx_cache = g_hash_table_new_full(__ctx_hash, __ctx_equal,
+                                              g_free, NULL);
+            g_hash_table_ref(ctx_cache);
+            g_hash_table_unref(qatomic_xchg(&s->ctx_cache, ctx_cache));
+        }
         g_hash_table_add(ctx_cache, ctx);
+        pthread_rwlock_unlock(&s->ctx_lock);
         *ref = ctx_cache;
         return ctx;
     }
@@ -1176,12 +1183,14 @@ static void riscv_iommu_iot_update(RISCVIOMMUState *s,
         return;
     }

+    pthread_rwlock_wrlock(&s->iot_lock);
     if (g_hash_table_size(s->iot_cache) >= s->iot_limit) {
         iot_cache = g_hash_table_new_full(__iot_hash, __iot_equal,
                                           g_free, NULL);
         g_hash_table_unref(qatomic_xchg(&s->iot_cache, iot_cache));
     }
     g_hash_table_add(iot_cache, iot);
+    pthread_rwlock_unlock(&s->iot_lock);
 }

 static void riscv_iommu_iot_inval(RISCVIOMMUState *s, GHFunc func,
@@ -1195,7 +1204,9 @@ static void
riscv_iommu_iot_inval(RISCVIOMMUState *s, GHFunc func,
     };

     iot_cache = g_hash_table_ref(s->iot_cache);
+    pthread_rwlock_wrlock(&s->iot_lock);
     g_hash_table_foreach(iot_cache, func, &key);
+    pthread_rwlock_unlock(&s->iot_lock);
     g_hash_table_unref(iot_cache);
 }

@@ -1227,7 +1238,9 @@ static int riscv_iommu_translate(RISCVIOMMUState
*s, RISCVIOMMUContext *ctx,
         }
     }

+    pthread_rwlock_rdlock(&s->iot_lock);
     iot = riscv_iommu_iot_lookup(ctx, iot_cache, iotlb->iova);
+    pthread_rwlock_unlock(&s->iot_lock);
     perm = iot ? iot->perm : IOMMU_NONE;
     if (perm != IOMMU_NONE) {
         iotlb->translated_addr = PPN_PHYS(iot->phys);
@@ -2085,6 +2098,8 @@ static void riscv_iommu_realize(DeviceState
*dev, Error **errp)
                                          g_free, NULL);
     s->iot_cache = g_hash_table_new_full(__iot_hash, __iot_equal,
                                          g_free, NULL);
+    pthread_rwlock_init(&s->ctx_lock, NULL);
+    pthread_rwlock_init(&s->iot_lock, NULL);

     s->iommus.le_next = NULL;
     s->iommus.le_prev = NULL;
diff --git a/hw/riscv/riscv-iommu.h b/hw/riscv/riscv-iommu.h
index 26236c3cee..041b3b9e05 100644
--- a/hw/riscv/riscv-iommu.h
+++ b/hw/riscv/riscv-iommu.h
@@ -71,7 +71,9 @@ struct RISCVIOMMUState {
     MemoryRegion trap_mr;

     GHashTable *ctx_cache;          /* Device translation Context Cache */
+    pthread_rwlock_t ctx_lock;      /* Device translation Cache update lock */
     GHashTable *iot_cache;          /* IO Translated Address Cache */
+    pthread_rwlock_t iot_lock;      /* IO TLB Cache update lock */
     unsigned iot_limit;             /* IO Translation Cache size limit */

     /* MMIO Hardware Interface */


On Thu, May 23, 2024 at 10:40 AM Daniel Henrique Barboza
<dbarboza@ventanamicro.com> wrote:
>
> From: Tomasz Jeznach <tjeznach@rivosinc.com>
>
> The RISC-V IOMMU spec predicts that the IOMMU can use translation caches
> to hold entries from the DDT. This includes implementation for all cache
> commands that are marked as 'not implemented'.
>
> There are some artifacts included in the cache that predicts s-stage and
> g-stage elements, although we don't support it yet. We'll introduce them
> next.
>
> Signed-off-by: Tomasz Jeznach <tjeznach@rivosinc.com>
> Signed-off-by: Daniel Henrique Barboza <dbarboza@ventanamicro.com>
> Reviewed-by: Frank Chang <frank.chang@sifive.com>
> ---
>  hw/riscv/riscv-iommu.c | 189 ++++++++++++++++++++++++++++++++++++++++-
>  hw/riscv/riscv-iommu.h |   2 +
>  2 files changed, 187 insertions(+), 4 deletions(-)
>
> diff --git a/hw/riscv/riscv-iommu.c b/hw/riscv/riscv-iommu.c
> index 39b4ff1405..abf6ae7726 100644
> --- a/hw/riscv/riscv-iommu.c
> +++ b/hw/riscv/riscv-iommu.c
> @@ -63,6 +63,16 @@ struct RISCVIOMMUContext {
>      uint64_t msiptp;            /* MSI redirection page table pointer */
>  };
>
> +/* Address translation cache entry */
> +struct RISCVIOMMUEntry {
> +    uint64_t iova:44;           /* IOVA Page Number */
> +    uint64_t pscid:20;          /* Process Soft-Context identifier */
> +    uint64_t phys:44;           /* Physical Page Number */
> +    uint64_t gscid:16;          /* Guest Soft-Context identifier */
> +    uint64_t perm:2;            /* IOMMU_RW flags */
> +    uint64_t __rfu:2;
> +};
> +
>  /* IOMMU index for transactions without PASID specified. */
>  #define RISCV_IOMMU_NOPASID 0
>
> @@ -751,13 +761,125 @@ static AddressSpace *riscv_iommu_space(RISCVIOMMUState *s, uint32_t devid)
>      return &as->iova_as;
>  }
>
> +/* Translation Object cache support */
> +static gboolean __iot_equal(gconstpointer v1, gconstpointer v2)
> +{
> +    RISCVIOMMUEntry *t1 = (RISCVIOMMUEntry *) v1;
> +    RISCVIOMMUEntry *t2 = (RISCVIOMMUEntry *) v2;
> +    return t1->gscid == t2->gscid && t1->pscid == t2->pscid &&
> +           t1->iova == t2->iova;
> +}
> +
> +static guint __iot_hash(gconstpointer v)
> +{
> +    RISCVIOMMUEntry *t = (RISCVIOMMUEntry *) v;
> +    return (guint)t->iova;
> +}
> +
> +/* GV: 1 PSCV: 1 AV: 1 */
> +static void __iot_inval_pscid_iova(gpointer key, gpointer value, gpointer data)
> +{
> +    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
> +    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
> +    if (iot->gscid == arg->gscid &&
> +        iot->pscid == arg->pscid &&
> +        iot->iova == arg->iova) {
> +        iot->perm = IOMMU_NONE;
> +    }
> +}
> +
> +/* GV: 1 PSCV: 1 AV: 0 */
> +static void __iot_inval_pscid(gpointer key, gpointer value, gpointer data)
> +{
> +    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
> +    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
> +    if (iot->gscid == arg->gscid &&
> +        iot->pscid == arg->pscid) {
> +        iot->perm = IOMMU_NONE;
> +    }
> +}
> +
> +/* GV: 1 GVMA: 1 */
> +static void __iot_inval_gscid_gpa(gpointer key, gpointer value, gpointer data)
> +{
> +    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
> +    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
> +    if (iot->gscid == arg->gscid) {
> +        /* simplified cache, no GPA matching */
> +        iot->perm = IOMMU_NONE;
> +    }
> +}
> +
> +/* GV: 1 GVMA: 0 */
> +static void __iot_inval_gscid(gpointer key, gpointer value, gpointer data)
> +{
> +    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
> +    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
> +    if (iot->gscid == arg->gscid) {
> +        iot->perm = IOMMU_NONE;
> +    }
> +}
> +
> +/* GV: 0 */
> +static void __iot_inval_all(gpointer key, gpointer value, gpointer data)
> +{
> +    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
> +    iot->perm = IOMMU_NONE;
> +}
> +
> +/* caller should keep ref-count for iot_cache object */
> +static RISCVIOMMUEntry *riscv_iommu_iot_lookup(RISCVIOMMUContext *ctx,
> +    GHashTable *iot_cache, hwaddr iova)
> +{
> +    RISCVIOMMUEntry key = {
> +        .pscid = get_field(ctx->ta, RISCV_IOMMU_DC_TA_PSCID),
> +        .iova  = PPN_DOWN(iova),
> +    };
> +    return g_hash_table_lookup(iot_cache, &key);
> +}
> +
> +/* caller should keep ref-count for iot_cache object */
> +static void riscv_iommu_iot_update(RISCVIOMMUState *s,
> +    GHashTable *iot_cache, RISCVIOMMUEntry *iot)
> +{
> +    if (!s->iot_limit) {
> +        return;
> +    }
> +
> +    if (g_hash_table_size(s->iot_cache) >= s->iot_limit) {
> +        iot_cache = g_hash_table_new_full(__iot_hash, __iot_equal,
> +                                          g_free, NULL);
> +        g_hash_table_unref(qatomic_xchg(&s->iot_cache, iot_cache));
> +    }
> +    g_hash_table_add(iot_cache, iot);
> +}
> +
> +static void riscv_iommu_iot_inval(RISCVIOMMUState *s, GHFunc func,
> +    uint32_t gscid, uint32_t pscid, hwaddr iova)
> +{
> +    GHashTable *iot_cache;
> +    RISCVIOMMUEntry key = {
> +        .gscid = gscid,
> +        .pscid = pscid,
> +        .iova  = PPN_DOWN(iova),
> +    };
> +
> +    iot_cache = g_hash_table_ref(s->iot_cache);
> +    g_hash_table_foreach(iot_cache, func, &key);
> +    g_hash_table_unref(iot_cache);
> +}
> +
>  static int riscv_iommu_translate(RISCVIOMMUState *s, RISCVIOMMUContext *ctx,
> -    IOMMUTLBEntry *iotlb)
> +    IOMMUTLBEntry *iotlb, bool enable_cache)
>  {
> +    RISCVIOMMUEntry *iot;
> +    IOMMUAccessFlags perm;
>      bool enable_pasid;
>      bool enable_pri;
> +    GHashTable *iot_cache;
>      int fault;
>
> +    iot_cache = g_hash_table_ref(s->iot_cache);
>      /*
>       * TC[32] is reserved for custom extensions, used here to temporarily
>       * enable automatic page-request generation for ATS queries.
> @@ -765,9 +887,36 @@ static int riscv_iommu_translate(RISCVIOMMUState *s, RISCVIOMMUContext *ctx,
>      enable_pri = (iotlb->perm == IOMMU_NONE) && (ctx->tc & BIT_ULL(32));
>      enable_pasid = (ctx->tc & RISCV_IOMMU_DC_TC_PDTV);
>
> +    iot = riscv_iommu_iot_lookup(ctx, iot_cache, iotlb->iova);
> +    perm = iot ? iot->perm : IOMMU_NONE;
> +    if (perm != IOMMU_NONE) {
> +        iotlb->translated_addr = PPN_PHYS(iot->phys);
> +        iotlb->addr_mask = ~TARGET_PAGE_MASK;
> +        iotlb->perm = perm;
> +        fault = 0;
> +        goto done;
> +    }
> +
>      /* Translate using device directory / page table information. */
>      fault = riscv_iommu_spa_fetch(s, ctx, iotlb);
>
> +    if (!fault && iotlb->target_as == &s->trap_as) {
> +        /* Do not cache trapped MSI translations */
> +        goto done;
> +    }
> +
> +    if (!fault && iotlb->translated_addr != iotlb->iova && enable_cache) {
> +        iot = g_new0(RISCVIOMMUEntry, 1);
> +        iot->iova = PPN_DOWN(iotlb->iova);
> +        iot->phys = PPN_DOWN(iotlb->translated_addr);
> +        iot->pscid = get_field(ctx->ta, RISCV_IOMMU_DC_TA_PSCID);
> +        iot->perm = iotlb->perm;
> +        riscv_iommu_iot_update(s, iot_cache, iot);
> +    }
> +
> +done:
> +    g_hash_table_unref(iot_cache);
> +
>      if (enable_pri && fault) {
>          struct riscv_iommu_pq_record pr = {0};
>          if (enable_pasid) {
> @@ -907,13 +1056,40 @@ static void riscv_iommu_process_cq_tail(RISCVIOMMUState *s)
>              if (cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_PSCV) {
>                  /* illegal command arguments IOTINVAL.GVMA & PSCV == 1 */
>                  goto cmd_ill;
> +            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_GV)) {
> +                /* invalidate all cache mappings */
> +                func = __iot_inval_all;
> +            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_AV)) {
> +                /* invalidate cache matching GSCID */
> +                func = __iot_inval_gscid;
> +            } else {
> +                /* invalidate cache matching GSCID and ADDR (GPA) */
> +                func = __iot_inval_gscid_gpa;
>              }
> -            /* translation cache not implemented yet */
> +            riscv_iommu_iot_inval(s, func,
> +                get_field(cmd.dword0, RISCV_IOMMU_CMD_IOTINVAL_GSCID), 0,
> +                cmd.dword1 & TARGET_PAGE_MASK);
>              break;
>
>          case RISCV_IOMMU_CMD(RISCV_IOMMU_CMD_IOTINVAL_FUNC_VMA,
>                               RISCV_IOMMU_CMD_IOTINVAL_OPCODE):
> -            /* translation cache not implemented yet */
> +            if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_GV)) {
> +                /* invalidate all cache mappings, simplified model */
> +                func = __iot_inval_all;
> +            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_PSCV)) {
> +                /* invalidate cache matching GSCID, simplified model */
> +                func = __iot_inval_gscid;
> +            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_AV)) {
> +                /* invalidate cache matching GSCID and PSCID */
> +                func = __iot_inval_pscid;
> +            } else {
> +                /* invalidate cache matching GSCID and PSCID and ADDR (IOVA) */
> +                func = __iot_inval_pscid_iova;
> +            }
> +            riscv_iommu_iot_inval(s, func,
> +                get_field(cmd.dword0, RISCV_IOMMU_CMD_IOTINVAL_GSCID),
> +                get_field(cmd.dword0, RISCV_IOMMU_CMD_IOTINVAL_PSCID),
> +                cmd.dword1 & TARGET_PAGE_MASK);
>              break;
>
>          case RISCV_IOMMU_CMD(RISCV_IOMMU_CMD_IODIR_FUNC_INVAL_DDT,
> @@ -1410,6 +1586,8 @@ static void riscv_iommu_realize(DeviceState *dev, Error **errp)
>      /* Device translation context cache */
>      s->ctx_cache = g_hash_table_new_full(__ctx_hash, __ctx_equal,
>                                           g_free, NULL);
> +    s->iot_cache = g_hash_table_new_full(__iot_hash, __iot_equal,
> +                                         g_free, NULL);
>
>      s->iommus.le_next = NULL;
>      s->iommus.le_prev = NULL;
> @@ -1423,6 +1601,7 @@ static void riscv_iommu_unrealize(DeviceState *dev)
>      RISCVIOMMUState *s = RISCV_IOMMU(dev);
>
>      qemu_mutex_destroy(&s->core_lock);
> +    g_hash_table_unref(s->iot_cache);
>      g_hash_table_unref(s->ctx_cache);
>  }
>
> @@ -1430,6 +1609,8 @@ static Property riscv_iommu_properties[] = {
>      DEFINE_PROP_UINT32("version", RISCVIOMMUState, version,
>          RISCV_IOMMU_SPEC_DOT_VER),
>      DEFINE_PROP_UINT32("bus", RISCVIOMMUState, bus, 0x0),
> +    DEFINE_PROP_UINT32("ioatc-limit", RISCVIOMMUState, iot_limit,
> +        LIMIT_CACHE_IOT),
>      DEFINE_PROP_BOOL("intremap", RISCVIOMMUState, enable_msi, TRUE),
>      DEFINE_PROP_BOOL("off", RISCVIOMMUState, enable_off, TRUE),
>      DEFINE_PROP_LINK("downstream-mr", RISCVIOMMUState, target_mr,
> @@ -1482,7 +1663,7 @@ static IOMMUTLBEntry riscv_iommu_memory_region_translate(
>          /* Translation disabled or invalid. */
>          iotlb.addr_mask = 0;
>          iotlb.perm = IOMMU_NONE;
> -    } else if (riscv_iommu_translate(as->iommu, ctx, &iotlb)) {
> +    } else if (riscv_iommu_translate(as->iommu, ctx, &iotlb, true)) {
>          /* Translation disabled or fault reported. */
>          iotlb.addr_mask = 0;
>          iotlb.perm = IOMMU_NONE;
> diff --git a/hw/riscv/riscv-iommu.h b/hw/riscv/riscv-iommu.h
> index 31d3907d33..3afee9f3e8 100644
> --- a/hw/riscv/riscv-iommu.h
> +++ b/hw/riscv/riscv-iommu.h
> @@ -68,6 +68,8 @@ struct RISCVIOMMUState {
>      MemoryRegion trap_mr;
>
>      GHashTable *ctx_cache;          /* Device translation Context Cache */
> +    GHashTable *iot_cache;          /* IO Translated Address Cache */
> +    unsigned iot_limit;             /* IO Translation Cache size limit */
>
>      /* MMIO Hardware Interface */
>      MemoryRegion regs_mr;
> --
> 2.44.0
>
Daniel Henrique Barboza June 7, 2024, 8:30 a.m. UTC | #2
Hi Tomasz,

On 6/5/24 2:34 PM, Tomasz Jeznach wrote:
> Daniel,
> 
> Thank you for your upstreaming work!

Glad to help!

> 
> I've synchronized the private branch with v3 changes, and noticed
> there is an important change missing in this patchset. We need
> reader-writer lock around access to GLib.HashTable as it's not
> MT-safe.  Diff added below, also available on github [1] branch
> riscv_iommu_v4-rc1.
> 
> [1] link: https://github.com/tjeznach/qemu/tree/riscv_iommu_v4-rc1


Just picked the changes and squashed them in patch 3 and 8. Thanks!


Daniel

> 
> Thanks!
> - Tomasz Jeznach
> 
> diff --git a/hw/riscv/riscv-iommu.c b/hw/riscv/riscv-iommu.c
> index a27f56419a..75c5d645fc 100644
> --- a/hw/riscv/riscv-iommu.c
> +++ b/hw/riscv/riscv-iommu.c
> @@ -991,7 +991,9 @@ static void riscv_iommu_ctx_inval(RISCVIOMMUState
> *s, GHFunc func,
>           .pasid = pasid,
>       };
>       ctx_cache = g_hash_table_ref(s->ctx_cache);
> +    pthread_rwlock_wrlock(&s->ctx_lock);
>       g_hash_table_foreach(ctx_cache, func, &key);
> +    pthread_rwlock_unlock(&s->ctx_lock);
>       g_hash_table_unref(ctx_cache);
>   }
> 
> @@ -1007,26 +1009,31 @@ static RISCVIOMMUContext
> *riscv_iommu_ctx(RISCVIOMMUState *s,
>       };
> 
>       ctx_cache = g_hash_table_ref(s->ctx_cache);
> +    pthread_rwlock_rdlock(&s->ctx_lock);
>       ctx = g_hash_table_lookup(ctx_cache, &key);
> +    pthread_rwlock_unlock(&s->ctx_lock);
> 
>       if (ctx && (ctx->tc & RISCV_IOMMU_DC_TC_V)) {
>           *ref = ctx_cache;
>           return ctx;
>       }
> 
> -    if (g_hash_table_size(s->ctx_cache) >= LIMIT_CACHE_CTX) {
> -        ctx_cache = g_hash_table_new_full(__ctx_hash, __ctx_equal,
> -                                          g_free, NULL);
> -        g_hash_table_unref(qatomic_xchg(&s->ctx_cache, ctx_cache));
> -    }
> -
>       ctx = g_new0(RISCVIOMMUContext, 1);
>       ctx->devid = devid;
>       ctx->pasid = pasid;
> 
>       int fault = riscv_iommu_ctx_fetch(s, ctx);
>       if (!fault) {
> +        pthread_rwlock_wrlock(&s->ctx_lock);
> +        if (g_hash_table_size(ctx_cache) >= LIMIT_CACHE_CTX) {
> +            g_hash_table_unref(ctx_cache);
> +            ctx_cache = g_hash_table_new_full(__ctx_hash, __ctx_equal,
> +                                              g_free, NULL);
> +            g_hash_table_ref(ctx_cache);
> +            g_hash_table_unref(qatomic_xchg(&s->ctx_cache, ctx_cache));
> +        }
>           g_hash_table_add(ctx_cache, ctx);
> +        pthread_rwlock_unlock(&s->ctx_lock);
>           *ref = ctx_cache;
>           return ctx;
>       }
> @@ -1176,12 +1183,14 @@ static void riscv_iommu_iot_update(RISCVIOMMUState *s,
>           return;
>       }
> 
> +    pthread_rwlock_wrlock(&s->iot_lock);
>       if (g_hash_table_size(s->iot_cache) >= s->iot_limit) {
>           iot_cache = g_hash_table_new_full(__iot_hash, __iot_equal,
>                                             g_free, NULL);
>           g_hash_table_unref(qatomic_xchg(&s->iot_cache, iot_cache));
>       }
>       g_hash_table_add(iot_cache, iot);
> +    pthread_rwlock_unlock(&s->iot_lock);
>   }
> 
>   static void riscv_iommu_iot_inval(RISCVIOMMUState *s, GHFunc func,
> @@ -1195,7 +1204,9 @@ static void
> riscv_iommu_iot_inval(RISCVIOMMUState *s, GHFunc func,
>       };
> 
>       iot_cache = g_hash_table_ref(s->iot_cache);
> +    pthread_rwlock_wrlock(&s->iot_lock);
>       g_hash_table_foreach(iot_cache, func, &key);
> +    pthread_rwlock_unlock(&s->iot_lock);
>       g_hash_table_unref(iot_cache);
>   }
> 
> @@ -1227,7 +1238,9 @@ static int riscv_iommu_translate(RISCVIOMMUState
> *s, RISCVIOMMUContext *ctx,
>           }
>       }
> 
> +    pthread_rwlock_rdlock(&s->iot_lock);
>       iot = riscv_iommu_iot_lookup(ctx, iot_cache, iotlb->iova);
> +    pthread_rwlock_unlock(&s->iot_lock);
>       perm = iot ? iot->perm : IOMMU_NONE;
>       if (perm != IOMMU_NONE) {
>           iotlb->translated_addr = PPN_PHYS(iot->phys);
> @@ -2085,6 +2098,8 @@ static void riscv_iommu_realize(DeviceState
> *dev, Error **errp)
>                                            g_free, NULL);
>       s->iot_cache = g_hash_table_new_full(__iot_hash, __iot_equal,
>                                            g_free, NULL);
> +    pthread_rwlock_init(&s->ctx_lock, NULL);
> +    pthread_rwlock_init(&s->iot_lock, NULL);
> 
>       s->iommus.le_next = NULL;
>       s->iommus.le_prev = NULL;
> diff --git a/hw/riscv/riscv-iommu.h b/hw/riscv/riscv-iommu.h
> index 26236c3cee..041b3b9e05 100644
> --- a/hw/riscv/riscv-iommu.h
> +++ b/hw/riscv/riscv-iommu.h
> @@ -71,7 +71,9 @@ struct RISCVIOMMUState {
>       MemoryRegion trap_mr;
> 
>       GHashTable *ctx_cache;          /* Device translation Context Cache */
> +    pthread_rwlock_t ctx_lock;      /* Device translation Cache update lock */
>       GHashTable *iot_cache;          /* IO Translated Address Cache */
> +    pthread_rwlock_t iot_lock;      /* IO TLB Cache update lock */
>       unsigned iot_limit;             /* IO Translation Cache size limit */
> 
>       /* MMIO Hardware Interface */
> 
> 
> On Thu, May 23, 2024 at 10:40 AM Daniel Henrique Barboza
> <dbarboza@ventanamicro.com> wrote:
>>
>> From: Tomasz Jeznach <tjeznach@rivosinc.com>
>>
>> The RISC-V IOMMU spec predicts that the IOMMU can use translation caches
>> to hold entries from the DDT. This includes implementation for all cache
>> commands that are marked as 'not implemented'.
>>
>> There are some artifacts included in the cache that predicts s-stage and
>> g-stage elements, although we don't support it yet. We'll introduce them
>> next.
>>
>> Signed-off-by: Tomasz Jeznach <tjeznach@rivosinc.com>
>> Signed-off-by: Daniel Henrique Barboza <dbarboza@ventanamicro.com>
>> Reviewed-by: Frank Chang <frank.chang@sifive.com>
>> ---
>>   hw/riscv/riscv-iommu.c | 189 ++++++++++++++++++++++++++++++++++++++++-
>>   hw/riscv/riscv-iommu.h |   2 +
>>   2 files changed, 187 insertions(+), 4 deletions(-)
>>
>> diff --git a/hw/riscv/riscv-iommu.c b/hw/riscv/riscv-iommu.c
>> index 39b4ff1405..abf6ae7726 100644
>> --- a/hw/riscv/riscv-iommu.c
>> +++ b/hw/riscv/riscv-iommu.c
>> @@ -63,6 +63,16 @@ struct RISCVIOMMUContext {
>>       uint64_t msiptp;            /* MSI redirection page table pointer */
>>   };
>>
>> +/* Address translation cache entry */
>> +struct RISCVIOMMUEntry {
>> +    uint64_t iova:44;           /* IOVA Page Number */
>> +    uint64_t pscid:20;          /* Process Soft-Context identifier */
>> +    uint64_t phys:44;           /* Physical Page Number */
>> +    uint64_t gscid:16;          /* Guest Soft-Context identifier */
>> +    uint64_t perm:2;            /* IOMMU_RW flags */
>> +    uint64_t __rfu:2;
>> +};
>> +
>>   /* IOMMU index for transactions without PASID specified. */
>>   #define RISCV_IOMMU_NOPASID 0
>>
>> @@ -751,13 +761,125 @@ static AddressSpace *riscv_iommu_space(RISCVIOMMUState *s, uint32_t devid)
>>       return &as->iova_as;
>>   }
>>
>> +/* Translation Object cache support */
>> +static gboolean __iot_equal(gconstpointer v1, gconstpointer v2)
>> +{
>> +    RISCVIOMMUEntry *t1 = (RISCVIOMMUEntry *) v1;
>> +    RISCVIOMMUEntry *t2 = (RISCVIOMMUEntry *) v2;
>> +    return t1->gscid == t2->gscid && t1->pscid == t2->pscid &&
>> +           t1->iova == t2->iova;
>> +}
>> +
>> +static guint __iot_hash(gconstpointer v)
>> +{
>> +    RISCVIOMMUEntry *t = (RISCVIOMMUEntry *) v;
>> +    return (guint)t->iova;
>> +}
>> +
>> +/* GV: 1 PSCV: 1 AV: 1 */
>> +static void __iot_inval_pscid_iova(gpointer key, gpointer value, gpointer data)
>> +{
>> +    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
>> +    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
>> +    if (iot->gscid == arg->gscid &&
>> +        iot->pscid == arg->pscid &&
>> +        iot->iova == arg->iova) {
>> +        iot->perm = IOMMU_NONE;
>> +    }
>> +}
>> +
>> +/* GV: 1 PSCV: 1 AV: 0 */
>> +static void __iot_inval_pscid(gpointer key, gpointer value, gpointer data)
>> +{
>> +    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
>> +    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
>> +    if (iot->gscid == arg->gscid &&
>> +        iot->pscid == arg->pscid) {
>> +        iot->perm = IOMMU_NONE;
>> +    }
>> +}
>> +
>> +/* GV: 1 GVMA: 1 */
>> +static void __iot_inval_gscid_gpa(gpointer key, gpointer value, gpointer data)
>> +{
>> +    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
>> +    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
>> +    if (iot->gscid == arg->gscid) {
>> +        /* simplified cache, no GPA matching */
>> +        iot->perm = IOMMU_NONE;
>> +    }
>> +}
>> +
>> +/* GV: 1 GVMA: 0 */
>> +static void __iot_inval_gscid(gpointer key, gpointer value, gpointer data)
>> +{
>> +    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
>> +    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
>> +    if (iot->gscid == arg->gscid) {
>> +        iot->perm = IOMMU_NONE;
>> +    }
>> +}
>> +
>> +/* GV: 0 */
>> +static void __iot_inval_all(gpointer key, gpointer value, gpointer data)
>> +{
>> +    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
>> +    iot->perm = IOMMU_NONE;
>> +}
>> +
>> +/* caller should keep ref-count for iot_cache object */
>> +static RISCVIOMMUEntry *riscv_iommu_iot_lookup(RISCVIOMMUContext *ctx,
>> +    GHashTable *iot_cache, hwaddr iova)
>> +{
>> +    RISCVIOMMUEntry key = {
>> +        .pscid = get_field(ctx->ta, RISCV_IOMMU_DC_TA_PSCID),
>> +        .iova  = PPN_DOWN(iova),
>> +    };
>> +    return g_hash_table_lookup(iot_cache, &key);
>> +}
>> +
>> +/* caller should keep ref-count for iot_cache object */
>> +static void riscv_iommu_iot_update(RISCVIOMMUState *s,
>> +    GHashTable *iot_cache, RISCVIOMMUEntry *iot)
>> +{
>> +    if (!s->iot_limit) {
>> +        return;
>> +    }
>> +
>> +    if (g_hash_table_size(s->iot_cache) >= s->iot_limit) {
>> +        iot_cache = g_hash_table_new_full(__iot_hash, __iot_equal,
>> +                                          g_free, NULL);
>> +        g_hash_table_unref(qatomic_xchg(&s->iot_cache, iot_cache));
>> +    }
>> +    g_hash_table_add(iot_cache, iot);
>> +}
>> +
>> +static void riscv_iommu_iot_inval(RISCVIOMMUState *s, GHFunc func,
>> +    uint32_t gscid, uint32_t pscid, hwaddr iova)
>> +{
>> +    GHashTable *iot_cache;
>> +    RISCVIOMMUEntry key = {
>> +        .gscid = gscid,
>> +        .pscid = pscid,
>> +        .iova  = PPN_DOWN(iova),
>> +    };
>> +
>> +    iot_cache = g_hash_table_ref(s->iot_cache);
>> +    g_hash_table_foreach(iot_cache, func, &key);
>> +    g_hash_table_unref(iot_cache);
>> +}
>> +
>>   static int riscv_iommu_translate(RISCVIOMMUState *s, RISCVIOMMUContext *ctx,
>> -    IOMMUTLBEntry *iotlb)
>> +    IOMMUTLBEntry *iotlb, bool enable_cache)
>>   {
>> +    RISCVIOMMUEntry *iot;
>> +    IOMMUAccessFlags perm;
>>       bool enable_pasid;
>>       bool enable_pri;
>> +    GHashTable *iot_cache;
>>       int fault;
>>
>> +    iot_cache = g_hash_table_ref(s->iot_cache);
>>       /*
>>        * TC[32] is reserved for custom extensions, used here to temporarily
>>        * enable automatic page-request generation for ATS queries.
>> @@ -765,9 +887,36 @@ static int riscv_iommu_translate(RISCVIOMMUState *s, RISCVIOMMUContext *ctx,
>>       enable_pri = (iotlb->perm == IOMMU_NONE) && (ctx->tc & BIT_ULL(32));
>>       enable_pasid = (ctx->tc & RISCV_IOMMU_DC_TC_PDTV);
>>
>> +    iot = riscv_iommu_iot_lookup(ctx, iot_cache, iotlb->iova);
>> +    perm = iot ? iot->perm : IOMMU_NONE;
>> +    if (perm != IOMMU_NONE) {
>> +        iotlb->translated_addr = PPN_PHYS(iot->phys);
>> +        iotlb->addr_mask = ~TARGET_PAGE_MASK;
>> +        iotlb->perm = perm;
>> +        fault = 0;
>> +        goto done;
>> +    }
>> +
>>       /* Translate using device directory / page table information. */
>>       fault = riscv_iommu_spa_fetch(s, ctx, iotlb);
>>
>> +    if (!fault && iotlb->target_as == &s->trap_as) {
>> +        /* Do not cache trapped MSI translations */
>> +        goto done;
>> +    }
>> +
>> +    if (!fault && iotlb->translated_addr != iotlb->iova && enable_cache) {
>> +        iot = g_new0(RISCVIOMMUEntry, 1);
>> +        iot->iova = PPN_DOWN(iotlb->iova);
>> +        iot->phys = PPN_DOWN(iotlb->translated_addr);
>> +        iot->pscid = get_field(ctx->ta, RISCV_IOMMU_DC_TA_PSCID);
>> +        iot->perm = iotlb->perm;
>> +        riscv_iommu_iot_update(s, iot_cache, iot);
>> +    }
>> +
>> +done:
>> +    g_hash_table_unref(iot_cache);
>> +
>>       if (enable_pri && fault) {
>>           struct riscv_iommu_pq_record pr = {0};
>>           if (enable_pasid) {
>> @@ -907,13 +1056,40 @@ static void riscv_iommu_process_cq_tail(RISCVIOMMUState *s)
>>               if (cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_PSCV) {
>>                   /* illegal command arguments IOTINVAL.GVMA & PSCV == 1 */
>>                   goto cmd_ill;
>> +            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_GV)) {
>> +                /* invalidate all cache mappings */
>> +                func = __iot_inval_all;
>> +            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_AV)) {
>> +                /* invalidate cache matching GSCID */
>> +                func = __iot_inval_gscid;
>> +            } else {
>> +                /* invalidate cache matching GSCID and ADDR (GPA) */
>> +                func = __iot_inval_gscid_gpa;
>>               }
>> -            /* translation cache not implemented yet */
>> +            riscv_iommu_iot_inval(s, func,
>> +                get_field(cmd.dword0, RISCV_IOMMU_CMD_IOTINVAL_GSCID), 0,
>> +                cmd.dword1 & TARGET_PAGE_MASK);
>>               break;
>>
>>           case RISCV_IOMMU_CMD(RISCV_IOMMU_CMD_IOTINVAL_FUNC_VMA,
>>                                RISCV_IOMMU_CMD_IOTINVAL_OPCODE):
>> -            /* translation cache not implemented yet */
>> +            if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_GV)) {
>> +                /* invalidate all cache mappings, simplified model */
>> +                func = __iot_inval_all;
>> +            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_PSCV)) {
>> +                /* invalidate cache matching GSCID, simplified model */
>> +                func = __iot_inval_gscid;
>> +            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_AV)) {
>> +                /* invalidate cache matching GSCID and PSCID */
>> +                func = __iot_inval_pscid;
>> +            } else {
>> +                /* invalidate cache matching GSCID and PSCID and ADDR (IOVA) */
>> +                func = __iot_inval_pscid_iova;
>> +            }
>> +            riscv_iommu_iot_inval(s, func,
>> +                get_field(cmd.dword0, RISCV_IOMMU_CMD_IOTINVAL_GSCID),
>> +                get_field(cmd.dword0, RISCV_IOMMU_CMD_IOTINVAL_PSCID),
>> +                cmd.dword1 & TARGET_PAGE_MASK);
>>               break;
>>
>>           case RISCV_IOMMU_CMD(RISCV_IOMMU_CMD_IODIR_FUNC_INVAL_DDT,
>> @@ -1410,6 +1586,8 @@ static void riscv_iommu_realize(DeviceState *dev, Error **errp)
>>       /* Device translation context cache */
>>       s->ctx_cache = g_hash_table_new_full(__ctx_hash, __ctx_equal,
>>                                            g_free, NULL);
>> +    s->iot_cache = g_hash_table_new_full(__iot_hash, __iot_equal,
>> +                                         g_free, NULL);
>>
>>       s->iommus.le_next = NULL;
>>       s->iommus.le_prev = NULL;
>> @@ -1423,6 +1601,7 @@ static void riscv_iommu_unrealize(DeviceState *dev)
>>       RISCVIOMMUState *s = RISCV_IOMMU(dev);
>>
>>       qemu_mutex_destroy(&s->core_lock);
>> +    g_hash_table_unref(s->iot_cache);
>>       g_hash_table_unref(s->ctx_cache);
>>   }
>>
>> @@ -1430,6 +1609,8 @@ static Property riscv_iommu_properties[] = {
>>       DEFINE_PROP_UINT32("version", RISCVIOMMUState, version,
>>           RISCV_IOMMU_SPEC_DOT_VER),
>>       DEFINE_PROP_UINT32("bus", RISCVIOMMUState, bus, 0x0),
>> +    DEFINE_PROP_UINT32("ioatc-limit", RISCVIOMMUState, iot_limit,
>> +        LIMIT_CACHE_IOT),
>>       DEFINE_PROP_BOOL("intremap", RISCVIOMMUState, enable_msi, TRUE),
>>       DEFINE_PROP_BOOL("off", RISCVIOMMUState, enable_off, TRUE),
>>       DEFINE_PROP_LINK("downstream-mr", RISCVIOMMUState, target_mr,
>> @@ -1482,7 +1663,7 @@ static IOMMUTLBEntry riscv_iommu_memory_region_translate(
>>           /* Translation disabled or invalid. */
>>           iotlb.addr_mask = 0;
>>           iotlb.perm = IOMMU_NONE;
>> -    } else if (riscv_iommu_translate(as->iommu, ctx, &iotlb)) {
>> +    } else if (riscv_iommu_translate(as->iommu, ctx, &iotlb, true)) {
>>           /* Translation disabled or fault reported. */
>>           iotlb.addr_mask = 0;
>>           iotlb.perm = IOMMU_NONE;
>> diff --git a/hw/riscv/riscv-iommu.h b/hw/riscv/riscv-iommu.h
>> index 31d3907d33..3afee9f3e8 100644
>> --- a/hw/riscv/riscv-iommu.h
>> +++ b/hw/riscv/riscv-iommu.h
>> @@ -68,6 +68,8 @@ struct RISCVIOMMUState {
>>       MemoryRegion trap_mr;
>>
>>       GHashTable *ctx_cache;          /* Device translation Context Cache */
>> +    GHashTable *iot_cache;          /* IO Translated Address Cache */
>> +    unsigned iot_limit;             /* IO Translation Cache size limit */
>>
>>       /* MMIO Hardware Interface */
>>       MemoryRegion regs_mr;
>> --
>> 2.44.0
>>
diff mbox series

Patch

diff --git a/hw/riscv/riscv-iommu.c b/hw/riscv/riscv-iommu.c
index 39b4ff1405..abf6ae7726 100644
--- a/hw/riscv/riscv-iommu.c
+++ b/hw/riscv/riscv-iommu.c
@@ -63,6 +63,16 @@  struct RISCVIOMMUContext {
     uint64_t msiptp;            /* MSI redirection page table pointer */
 };
 
+/* Address translation cache entry */
+struct RISCVIOMMUEntry {
+    uint64_t iova:44;           /* IOVA Page Number */
+    uint64_t pscid:20;          /* Process Soft-Context identifier */
+    uint64_t phys:44;           /* Physical Page Number */
+    uint64_t gscid:16;          /* Guest Soft-Context identifier */
+    uint64_t perm:2;            /* IOMMU_RW flags */
+    uint64_t __rfu:2;
+};
+
 /* IOMMU index for transactions without PASID specified. */
 #define RISCV_IOMMU_NOPASID 0
 
@@ -751,13 +761,125 @@  static AddressSpace *riscv_iommu_space(RISCVIOMMUState *s, uint32_t devid)
     return &as->iova_as;
 }
 
+/* Translation Object cache support */
+static gboolean __iot_equal(gconstpointer v1, gconstpointer v2)
+{
+    RISCVIOMMUEntry *t1 = (RISCVIOMMUEntry *) v1;
+    RISCVIOMMUEntry *t2 = (RISCVIOMMUEntry *) v2;
+    return t1->gscid == t2->gscid && t1->pscid == t2->pscid &&
+           t1->iova == t2->iova;
+}
+
+static guint __iot_hash(gconstpointer v)
+{
+    RISCVIOMMUEntry *t = (RISCVIOMMUEntry *) v;
+    return (guint)t->iova;
+}
+
+/* GV: 1 PSCV: 1 AV: 1 */
+static void __iot_inval_pscid_iova(gpointer key, gpointer value, gpointer data)
+{
+    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
+    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
+    if (iot->gscid == arg->gscid &&
+        iot->pscid == arg->pscid &&
+        iot->iova == arg->iova) {
+        iot->perm = IOMMU_NONE;
+    }
+}
+
+/* GV: 1 PSCV: 1 AV: 0 */
+static void __iot_inval_pscid(gpointer key, gpointer value, gpointer data)
+{
+    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
+    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
+    if (iot->gscid == arg->gscid &&
+        iot->pscid == arg->pscid) {
+        iot->perm = IOMMU_NONE;
+    }
+}
+
+/* GV: 1 GVMA: 1 */
+static void __iot_inval_gscid_gpa(gpointer key, gpointer value, gpointer data)
+{
+    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
+    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
+    if (iot->gscid == arg->gscid) {
+        /* simplified cache, no GPA matching */
+        iot->perm = IOMMU_NONE;
+    }
+}
+
+/* GV: 1 GVMA: 0 */
+static void __iot_inval_gscid(gpointer key, gpointer value, gpointer data)
+{
+    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
+    RISCVIOMMUEntry *arg = (RISCVIOMMUEntry *) data;
+    if (iot->gscid == arg->gscid) {
+        iot->perm = IOMMU_NONE;
+    }
+}
+
+/* GV: 0 */
+static void __iot_inval_all(gpointer key, gpointer value, gpointer data)
+{
+    RISCVIOMMUEntry *iot = (RISCVIOMMUEntry *) value;
+    iot->perm = IOMMU_NONE;
+}
+
+/* caller should keep ref-count for iot_cache object */
+static RISCVIOMMUEntry *riscv_iommu_iot_lookup(RISCVIOMMUContext *ctx,
+    GHashTable *iot_cache, hwaddr iova)
+{
+    RISCVIOMMUEntry key = {
+        .pscid = get_field(ctx->ta, RISCV_IOMMU_DC_TA_PSCID),
+        .iova  = PPN_DOWN(iova),
+    };
+    return g_hash_table_lookup(iot_cache, &key);
+}
+
+/* caller should keep ref-count for iot_cache object */
+static void riscv_iommu_iot_update(RISCVIOMMUState *s,
+    GHashTable *iot_cache, RISCVIOMMUEntry *iot)
+{
+    if (!s->iot_limit) {
+        return;
+    }
+
+    if (g_hash_table_size(s->iot_cache) >= s->iot_limit) {
+        iot_cache = g_hash_table_new_full(__iot_hash, __iot_equal,
+                                          g_free, NULL);
+        g_hash_table_unref(qatomic_xchg(&s->iot_cache, iot_cache));
+    }
+    g_hash_table_add(iot_cache, iot);
+}
+
+static void riscv_iommu_iot_inval(RISCVIOMMUState *s, GHFunc func,
+    uint32_t gscid, uint32_t pscid, hwaddr iova)
+{
+    GHashTable *iot_cache;
+    RISCVIOMMUEntry key = {
+        .gscid = gscid,
+        .pscid = pscid,
+        .iova  = PPN_DOWN(iova),
+    };
+
+    iot_cache = g_hash_table_ref(s->iot_cache);
+    g_hash_table_foreach(iot_cache, func, &key);
+    g_hash_table_unref(iot_cache);
+}
+
 static int riscv_iommu_translate(RISCVIOMMUState *s, RISCVIOMMUContext *ctx,
-    IOMMUTLBEntry *iotlb)
+    IOMMUTLBEntry *iotlb, bool enable_cache)
 {
+    RISCVIOMMUEntry *iot;
+    IOMMUAccessFlags perm;
     bool enable_pasid;
     bool enable_pri;
+    GHashTable *iot_cache;
     int fault;
 
+    iot_cache = g_hash_table_ref(s->iot_cache);
     /*
      * TC[32] is reserved for custom extensions, used here to temporarily
      * enable automatic page-request generation for ATS queries.
@@ -765,9 +887,36 @@  static int riscv_iommu_translate(RISCVIOMMUState *s, RISCVIOMMUContext *ctx,
     enable_pri = (iotlb->perm == IOMMU_NONE) && (ctx->tc & BIT_ULL(32));
     enable_pasid = (ctx->tc & RISCV_IOMMU_DC_TC_PDTV);
 
+    iot = riscv_iommu_iot_lookup(ctx, iot_cache, iotlb->iova);
+    perm = iot ? iot->perm : IOMMU_NONE;
+    if (perm != IOMMU_NONE) {
+        iotlb->translated_addr = PPN_PHYS(iot->phys);
+        iotlb->addr_mask = ~TARGET_PAGE_MASK;
+        iotlb->perm = perm;
+        fault = 0;
+        goto done;
+    }
+
     /* Translate using device directory / page table information. */
     fault = riscv_iommu_spa_fetch(s, ctx, iotlb);
 
+    if (!fault && iotlb->target_as == &s->trap_as) {
+        /* Do not cache trapped MSI translations */
+        goto done;
+    }
+
+    if (!fault && iotlb->translated_addr != iotlb->iova && enable_cache) {
+        iot = g_new0(RISCVIOMMUEntry, 1);
+        iot->iova = PPN_DOWN(iotlb->iova);
+        iot->phys = PPN_DOWN(iotlb->translated_addr);
+        iot->pscid = get_field(ctx->ta, RISCV_IOMMU_DC_TA_PSCID);
+        iot->perm = iotlb->perm;
+        riscv_iommu_iot_update(s, iot_cache, iot);
+    }
+
+done:
+    g_hash_table_unref(iot_cache);
+
     if (enable_pri && fault) {
         struct riscv_iommu_pq_record pr = {0};
         if (enable_pasid) {
@@ -907,13 +1056,40 @@  static void riscv_iommu_process_cq_tail(RISCVIOMMUState *s)
             if (cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_PSCV) {
                 /* illegal command arguments IOTINVAL.GVMA & PSCV == 1 */
                 goto cmd_ill;
+            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_GV)) {
+                /* invalidate all cache mappings */
+                func = __iot_inval_all;
+            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_AV)) {
+                /* invalidate cache matching GSCID */
+                func = __iot_inval_gscid;
+            } else {
+                /* invalidate cache matching GSCID and ADDR (GPA) */
+                func = __iot_inval_gscid_gpa;
             }
-            /* translation cache not implemented yet */
+            riscv_iommu_iot_inval(s, func,
+                get_field(cmd.dword0, RISCV_IOMMU_CMD_IOTINVAL_GSCID), 0,
+                cmd.dword1 & TARGET_PAGE_MASK);
             break;
 
         case RISCV_IOMMU_CMD(RISCV_IOMMU_CMD_IOTINVAL_FUNC_VMA,
                              RISCV_IOMMU_CMD_IOTINVAL_OPCODE):
-            /* translation cache not implemented yet */
+            if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_GV)) {
+                /* invalidate all cache mappings, simplified model */
+                func = __iot_inval_all;
+            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_PSCV)) {
+                /* invalidate cache matching GSCID, simplified model */
+                func = __iot_inval_gscid;
+            } else if (!(cmd.dword0 & RISCV_IOMMU_CMD_IOTINVAL_AV)) {
+                /* invalidate cache matching GSCID and PSCID */
+                func = __iot_inval_pscid;
+            } else {
+                /* invalidate cache matching GSCID and PSCID and ADDR (IOVA) */
+                func = __iot_inval_pscid_iova;
+            }
+            riscv_iommu_iot_inval(s, func,
+                get_field(cmd.dword0, RISCV_IOMMU_CMD_IOTINVAL_GSCID),
+                get_field(cmd.dword0, RISCV_IOMMU_CMD_IOTINVAL_PSCID),
+                cmd.dword1 & TARGET_PAGE_MASK);
             break;
 
         case RISCV_IOMMU_CMD(RISCV_IOMMU_CMD_IODIR_FUNC_INVAL_DDT,
@@ -1410,6 +1586,8 @@  static void riscv_iommu_realize(DeviceState *dev, Error **errp)
     /* Device translation context cache */
     s->ctx_cache = g_hash_table_new_full(__ctx_hash, __ctx_equal,
                                          g_free, NULL);
+    s->iot_cache = g_hash_table_new_full(__iot_hash, __iot_equal,
+                                         g_free, NULL);
 
     s->iommus.le_next = NULL;
     s->iommus.le_prev = NULL;
@@ -1423,6 +1601,7 @@  static void riscv_iommu_unrealize(DeviceState *dev)
     RISCVIOMMUState *s = RISCV_IOMMU(dev);
 
     qemu_mutex_destroy(&s->core_lock);
+    g_hash_table_unref(s->iot_cache);
     g_hash_table_unref(s->ctx_cache);
 }
 
@@ -1430,6 +1609,8 @@  static Property riscv_iommu_properties[] = {
     DEFINE_PROP_UINT32("version", RISCVIOMMUState, version,
         RISCV_IOMMU_SPEC_DOT_VER),
     DEFINE_PROP_UINT32("bus", RISCVIOMMUState, bus, 0x0),
+    DEFINE_PROP_UINT32("ioatc-limit", RISCVIOMMUState, iot_limit,
+        LIMIT_CACHE_IOT),
     DEFINE_PROP_BOOL("intremap", RISCVIOMMUState, enable_msi, TRUE),
     DEFINE_PROP_BOOL("off", RISCVIOMMUState, enable_off, TRUE),
     DEFINE_PROP_LINK("downstream-mr", RISCVIOMMUState, target_mr,
@@ -1482,7 +1663,7 @@  static IOMMUTLBEntry riscv_iommu_memory_region_translate(
         /* Translation disabled or invalid. */
         iotlb.addr_mask = 0;
         iotlb.perm = IOMMU_NONE;
-    } else if (riscv_iommu_translate(as->iommu, ctx, &iotlb)) {
+    } else if (riscv_iommu_translate(as->iommu, ctx, &iotlb, true)) {
         /* Translation disabled or fault reported. */
         iotlb.addr_mask = 0;
         iotlb.perm = IOMMU_NONE;
diff --git a/hw/riscv/riscv-iommu.h b/hw/riscv/riscv-iommu.h
index 31d3907d33..3afee9f3e8 100644
--- a/hw/riscv/riscv-iommu.h
+++ b/hw/riscv/riscv-iommu.h
@@ -68,6 +68,8 @@  struct RISCVIOMMUState {
     MemoryRegion trap_mr;
 
     GHashTable *ctx_cache;          /* Device translation Context Cache */
+    GHashTable *iot_cache;          /* IO Translated Address Cache */
+    unsigned iot_limit;             /* IO Translation Cache size limit */
 
     /* MMIO Hardware Interface */
     MemoryRegion regs_mr;