diff mbox series

[11/17] aarch64: Add support for ZA storage

Message ID 20240511115400.7587-12-richard.henderson@linaro.org
State New
Headers show
Series RISU misc updates | expand

Commit Message

Richard Henderson May 11, 2024, 11:53 a.m. UTC
Require NVL == SVL on startup, to make it easier to manage reginfo.
Most of the time PSTATE.SM would be active with PSTATE.ZA anyway,
for any non-trivial SME testing.

Extend saved storage only when PSTATE.ZA is active.
Use a carefully reserved uint16_t for saving SVCR.

Reviewed-by: Peter Maydell <peter.maydell@linaro.org>
Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
---
 risu_reginfo_aarch64.h |  52 ++++++++++++-
 risu_reginfo_aarch64.c | 161 ++++++++++++++++++++++++++++++++++++-----
 2 files changed, 191 insertions(+), 22 deletions(-)
diff mbox series

Patch

diff --git a/risu_reginfo_aarch64.h b/risu_reginfo_aarch64.h
index 536c12b..097b7ad 100644
--- a/risu_reginfo_aarch64.h
+++ b/risu_reginfo_aarch64.h
@@ -21,6 +21,43 @@ 
 #define SVE_VQ_MAX 16
 
 #define ROUND_UP(SIZE, POW2)    (((SIZE) + (POW2) - 1) & -(POW2))
+
+#ifdef ZA_MAGIC
+/* System headers have all Streaming SVE definitions. */
+typedef struct sve_context risu_sve_context;
+typedef struct za_context  risu_za_context;
+#else
+#define ZA_MAGIC         0x54366345
+#define SVE_SIG_FLAG_SM  1
+
+/* System headers missing flags field. */
+typedef struct {  /* stand-in for struct sve_context when ZA_MAGIC is undefined */
+    struct _aarch64_ctx head;  /* reginfo_init checks head.size */
+    uint16_t vl;  /* reginfo_init compares this with ri->sve_vl */
+    uint16_t flags;  /* reginfo_init reads SVE_SIG_FLAG_SM from here */
+    uint16_t reserved[2];
+} risu_sve_context;
+
+typedef struct {  /* stand-in for struct za_context when ZA_MAGIC is undefined */
+    struct _aarch64_ctx head;  /* reginfo_init checks head.size */
+    uint16_t vl;  /* reginfo_init compares this with ri->sve_vl */
+    uint16_t reserved[3];
+} risu_za_context;
+
+#define ZA_SIG_REGS_OFFSET \
+    ROUND_UP(sizeof(risu_za_context), SVE_VQ_BYTES)
+
+#define ZA_SIG_REGS_SIZE(vq) \
+    ((vq) * (vq) * SVE_VQ_BYTES * SVE_VQ_BYTES)
+
+#define ZA_SIG_ZAV_OFFSET(vq, n) /* byte offset of ZA vector n */ \
+    (ZA_SIG_REGS_OFFSET + (SVE_SIG_ZREG_SIZE(vq) * (n)))
+
+#define ZA_SIG_CONTEXT_SIZE(vq) \
+    (ZA_SIG_REGS_OFFSET + ZA_SIG_REGS_SIZE(vq))
+
+#endif /* ZA_MAGIC */
+
 #define RISU_SVE_REGS_SIZE(VQ)  ROUND_UP(SVE_SIG_REGS_SIZE(VQ), 16)
 #define RISU_SIMD_REGS_SIZE     (32 * 16)
 
@@ -36,12 +73,16 @@  struct reginfo {
     uint32_t fpsr;
     uint32_t fpcr;
     uint16_t sve_vl;
-    uint16_t reserved;
+    uint16_t svcr;
 
-    char extra[RISU_SVE_REGS_SIZE(SVE_VQ_MAX)]
+    char extra[RISU_SVE_REGS_SIZE(SVE_VQ_MAX) +
+               ZA_SIG_REGS_SIZE(SVE_VQ_MAX)]
         __attribute__((aligned(16)));
 };
 
+#define SVCR_SM  1  /* set by reginfo_init from sve->flags & SVE_SIG_FLAG_SM */
+#define SVCR_ZA  2  /* set by reginfo_init when ZA data was copied */
+
 static inline uint64_t *reginfo_vreg(struct reginfo *ri, int i)
 {
     return (uint64_t *)&ri->extra[i * 16];
@@ -59,4 +100,11 @@  static inline uint16_t *reginfo_preg(struct reginfo *ri, int vq, int i)
                                   SVE_SIG_REGS_OFFSET];
 }
 
+static inline uint64_t *reginfo_zav(struct reginfo *ri, int vq, int i)
+{   /* Pointer to ZA vector i inside ri->extra. */
+    return (uint64_t *)&ri->extra[RISU_SVE_REGS_SIZE(vq) + /* after SVE */
+                                  ZA_SIG_ZAV_OFFSET(vq, i) -
+                                  ZA_SIG_REGS_OFFSET]; /* drop header part */
+}
+
 #endif /* RISU_REGINFO_AARCH64_H */
diff --git a/risu_reginfo_aarch64.c b/risu_reginfo_aarch64.c
index 86e70ab..67a2999 100644
--- a/risu_reginfo_aarch64.c
+++ b/risu_reginfo_aarch64.c
@@ -25,25 +25,44 @@ 
 #include "risu.h"
 #include "risu_reginfo_aarch64.h"
 
+#ifndef PR_SME_SET_VL
+#define PR_SME_SET_VL 63
+#endif
+
 /* Should we test SVE register state */
 static int test_sve;
+static int test_za;  /* ZA VQ from --test-za; 0 means ZA is not tested */
 static const struct option extra_opts[] = {
     {"test-sve", required_argument, NULL, FIRST_ARCH_OPT },
+    {"test-za", required_argument, NULL, FIRST_ARCH_OPT + 1 },
     {0, 0, 0, 0}
 };
 
 const struct option * const arch_long_opts = &extra_opts[0];
 const char * const arch_extra_help
-    = "  --test-sve=<vq>        Compare SVE registers with VQ\n";
+    = "  --test-sve=<vq>        Compare SVE registers with VQ\n"
+      "  --test-za=<vq>         Compare ZA storage with VQ\n";
 
 void process_arch_opt(int opt, const char *arg)
 {
-    assert(opt == FIRST_ARCH_OPT);
-    test_sve = strtol(arg, 0, 10);
-
-    if (test_sve <= 0 || test_sve > SVE_VQ_MAX) {
-        fprintf(stderr, "Invalid value for VQ (1-%d)\n", SVE_VQ_MAX);
-        exit(EXIT_FAILURE);
+    switch (opt) {  /* dispatch on the arch-specific option id */
+    case FIRST_ARCH_OPT:  /* --test-sve=<vq> */
+        test_sve = strtol(arg, 0, 10);
+        if (test_sve <= 0 || test_sve > SVE_VQ_MAX) {
+            fprintf(stderr, "Invalid value for SVE VQ (1-%d)\n", SVE_VQ_MAX);
+            exit(EXIT_FAILURE);
+        }
+        break;
+    case FIRST_ARCH_OPT + 1:  /* --test-za=<vq> */
+        test_za = strtol(arg, 0, 10);
+        if (test_za <= 0 || test_za > SVE_VQ_MAX
+            || (test_za & (test_za - 1))) {  /* reject non powers of 2 */
+            fprintf(stderr, "Invalid ZA VQ (power of 2, 1-%d)\n", SVE_VQ_MAX);
+            exit(EXIT_FAILURE);
+        }
+        break;
+    default:
+        abort();  /* option id not in extra_opts */
     }
 }
 
@@ -51,6 +70,31 @@  void arch_init(void)
 {
     long want, got;
 
+    if (test_za) {
+        /*
+         * For now, reginfo requires NVL == SVL.
+         * There doesn't seem to be much advantage to differing.
+         */
+        if (test_sve && test_sve != test_za) {
+            fprintf(stderr, "Mismatched values for SVE and ZA VQ\n");
+            exit(EXIT_FAILURE);
+        }
+
+        want = sve_vl_from_vq(test_za);
+        got = prctl(PR_SME_SET_VL, want);
+        if (want != got) {
+            if (got >= 0) {
+                fprintf(stderr, "Unsupported VQ for ZA (%d != %d)\n",
+                        test_za, (int)sve_vq_from_vl(got));
+            } else if (errno == EINVAL) {
+                fprintf(stderr, "System does not support SME\n");
+            } else {
+                perror("prctl PR_SME_SET_VL");
+            }
+            exit(EXIT_FAILURE);
+        }
+    }
+
     if (test_sve) {
         want = sve_vl_from_vq(test_sve);
         got = prctl(PR_SVE_SET_VL, want);
@@ -75,6 +119,9 @@  int reginfo_size(struct reginfo *ri)
     if (ri->sve_vl) {
         int vq = sve_vq_from_vl(ri->sve_vl);
         size += RISU_SVE_REGS_SIZE(vq);
+        if (ri->svcr & SVCR_ZA) {
+            size += ZA_SIG_REGS_SIZE(vq);
+        }
     } else {
         size += RISU_SIMD_REGS_SIZE;
     }
@@ -84,10 +131,11 @@  int reginfo_size(struct reginfo *ri)
 /* reginfo_init: initialize with a ucontext */
 void reginfo_init(struct reginfo *ri, ucontext_t *uc, void *siaddr)
 {
-    int i;
+    int i, vq;
     struct _aarch64_ctx *ctx, *extra = NULL;
     struct fpsimd_context *fp = NULL;
-    struct sve_context *sve = NULL;
+    risu_sve_context *sve = NULL;
+    risu_za_context *za = NULL;
 
     /* necessary to be able to compare with memcmp later */
     memset(ri, 0, sizeof(*ri));
@@ -112,6 +160,9 @@  void reginfo_init(struct reginfo *ri, ucontext_t *uc, void *siaddr)
         case SVE_MAGIC:
             sve = (void *)ctx;
             break;
+        case ZA_MAGIC:
+            za = (void *)ctx;
+            break;
         case EXTRA_MAGIC:
             extra = (void *)((struct extra_context *)(ctx))->datap;
             break;
@@ -134,21 +185,55 @@  void reginfo_init(struct reginfo *ri, ucontext_t *uc, void *siaddr)
     ri->fpsr = fp->fpsr;
     ri->fpcr = fp->fpcr;
 
-    if (test_sve) {
-        int vq = test_sve;
+    /*
+     * Note that arch_init required NVL==SVL, so test_sve and test_za
+     * are equal when non-zero.  We will verify this matches below.
+     */
+    vq = test_sve | test_za;
+    ri->sve_vl = sve_vl_from_vq(vq);
 
-        if (sve == NULL) {
-            fprintf(stderr, "risu_reginfo_aarch64: failed to get SVE state\n");
+    if (test_za) {
+        if (za == NULL) {
+            /* ZA_MAGIC is supposed to be present, even if empty. */
+            fprintf(stderr, "risu_reginfo_aarch64: missing ZA state\n");
             return;
         }
+        assert(za->head.size >= ZA_SIG_CONTEXT_SIZE(0));
 
-        if (sve->vl != sve_vl_from_vq(vq)) {
+        if (za->vl != ri->sve_vl) {
             fprintf(stderr, "risu_reginfo_aarch64: "
-                    "unexpected SVE state: %d != %d\n",
-                    sve->vl, sve_vl_from_vq(vq));
+                    "unexpected ZA VQ: %d != %d\n",
+                    za->vl, ri->sve_vl);
+            return;
+        }
+        if (za->head.size == ZA_SIG_CONTEXT_SIZE(0)) {
+            /* ZA storage is disabled. */
+        } else if (za->head.size < ZA_SIG_CONTEXT_SIZE(vq)) {
+            fprintf(stderr, "risu_reginfo_aarch64: "
+                    "failed to get complete ZA state\n");
+            return;
+        } else {
+            ri->svcr |= SVCR_ZA;
+            memcpy(reginfo_zav(ri, vq, 0), (char *)za + ZA_SIG_REGS_OFFSET,
+                   ZA_SIG_CONTEXT_SIZE(vq) - ZA_SIG_REGS_OFFSET);
+        }
+    }
+
+    if (test_sve) {
+        if (sve == NULL) {
+            /* SVE_MAGIC is supposed to be present, even if empty. */
+            fprintf(stderr, "risu_reginfo_aarch64: missing SVE state\n");
             return;
         }
 
+        if (sve->vl != ri->sve_vl) {
+            fprintf(stderr, "risu_reginfo_aarch64: "
+                    "unexpected SVE VQ: %d != %d\n",
+                    sve->vl, ri->sve_vl);
+            return;
+        }
+
+        ri->svcr |= sve->flags & SVE_SIG_FLAG_SM;
         if (sve->head.size <= SVE_SIG_CONTEXT_SIZE(0)) {
             /* Only AdvSIMD state is present. */
         } else if (sve->head.size < SVE_SIG_CONTEXT_SIZE(vq)) {
@@ -156,7 +241,6 @@  void reginfo_init(struct reginfo *ri, ucontext_t *uc, void *siaddr)
                     "failed to get complete SVE state\n");
             return;
         } else {
-            ri->sve_vl = sve->vl;
             memcpy(reginfo_zreg(ri, vq, 0),
                    (char *)sve + SVE_SIG_REGS_OFFSET,
                    SVE_SIG_REGS_SIZE(vq));
@@ -164,7 +248,18 @@  void reginfo_init(struct reginfo *ri, ucontext_t *uc, void *siaddr)
         }
     }
 
-    memcpy(reginfo_vreg(ri, 0), fp->vregs, RISU_SIMD_REGS_SIZE);
+    /*
+     * Be prepared for ZA state present but SVE state absent (VQ != 0).
+     * In which case, copy AdvSIMD vregs into the low portion of zregs;
+     * pregs remain all zero.
+     */
+    if (vq == 0) {
+        memcpy(reginfo_vreg(ri, 0), fp->vregs, RISU_SIMD_REGS_SIZE);
+    } else {
+        for (i = 0; i < 32; ++i) {
+            memcpy(reginfo_zreg(ri, vq, i), &fp->vregs[i], 16);
+        }
+    }
 }
 
 /* reginfo_is_eq: compare the reginfo structs, returns nonzero if equal */
@@ -248,9 +343,11 @@  void reginfo_dump(struct reginfo *ri, FILE * f)
     fprintf(f, "  fpcr   : %08x\n", ri->fpcr);
 
     if (ri->sve_vl) {
-        int vq = sve_vq_from_vl(ri->sve_vl);
+        int vl = ri->sve_vl;
+        int vq = sve_vq_from_vl(vl);
 
-        fprintf(f, "  vl     : %d\n", ri->sve_vl);
+        fprintf(f, "  vl     : %d\n", vl);
+        fprintf(f, "  svcr   : %d\n", ri->svcr);
 
         for (i = 0; i < SVE_NUM_ZREGS; i++) {
             uint64_t *z = reginfo_zreg(ri, vq, i);
@@ -270,6 +367,14 @@  void reginfo_dump(struct reginfo *ri, FILE * f)
             sve_dump_preg(f, vq, p);
             fprintf(f, "\n");
         }
+
+        if (ri->svcr & SVCR_ZA) {
+            for (i = 0; i < vl; ++i) {
+                uint64_t *z = reginfo_zav(ri, vq, i);
+                fprintf(f, "  ZA[%-3d]: ", i);
+                sve_dump_zreg(f, vq, z);
+            }
+        }
         return;
     }
 
@@ -322,6 +427,10 @@  void reginfo_dump_mismatch(struct reginfo *m, struct reginfo *a, FILE * f)
         fprintf(f, "  vl     : %d vs %d\n", m->sve_vl, a->sve_vl);
     }
 
+    if (m->svcr != a->svcr) {
+        fprintf(f, "  svcr   : %d vs %d\n", m->svcr, a->svcr);
+    }
+
     if (m->sve_vl) {
         int vq = sve_vq_from_vl(m->sve_vl);
 
@@ -347,6 +456,18 @@  void reginfo_dump_mismatch(struct reginfo *m, struct reginfo *a, FILE * f)
                 sve_dump_preg_diff(f, vq, pm, pa);
             }
         }
+
+        if (m->svcr & a->svcr & SVCR_ZA) {
+            for (i = 0; i < vq * 16; i++) {
+                uint64_t *zm = reginfo_zav(m, vq, i);
+                uint64_t *za = reginfo_zav(a, vq, i);
+
+                if (!sve_zreg_is_eq(vq, zm, za)) {
+                    fprintf(f, "  ZA[%-3d]: ", i);
+                    sve_dump_zreg_diff(f, vq, zm, za);
+                }
+            }
+        }
         return;
     }