diff mbox series

[PULL,26/27] scripts/decodetree: Implement named field support

Message ID 20230530185949.410208-27-richard.henderson@linaro.org
State New
Headers show
Series [PULL,01/27] tcg: Fix register move type in tcg_out_ld_helper_ret | expand

Commit Message

Richard Henderson May 30, 2023, 6:59 p.m. UTC
From: Peter Maydell <peter.maydell@linaro.org>

Implement support for named fields, i.e.  where one field is defined
in terms of another, rather than directly in terms of bits extracted
from the instruction.

The new method referenced_fields() on all the Field classes returns a
list of fields that this field references.  This just passes through,
except for the new NamedField class.

We can then use referenced_fields() to:
 * construct a list of 'dangling references' for a format or
   pattern, which is the fields that the format/pattern uses but
   doesn't define itself
 * do a topological sort, so that we output "field = value"
   assignments in an order that means that we assign a field before
   we reference it in a subsequent assignment
 * check when we output the code for a pattern whether we need to
   fill in the format fields before or after the pattern fields, and
   do other error checking

Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
Reviewed-by: Richard Henderson <richard.henderson@linaro.org>
Message-Id: <20230523120447.728365-6-peter.maydell@linaro.org>
---
 scripts/decodetree.py | 145 ++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 139 insertions(+), 6 deletions(-)
diff mbox series

Patch

diff --git a/scripts/decodetree.py b/scripts/decodetree.py
index db019a25c6..13db585d04 100644
--- a/scripts/decodetree.py
+++ b/scripts/decodetree.py
@@ -290,6 +290,9 @@  def str_extract(self, lvalue_formatter):
         s = 's' if self.sign else ''
         return f'{s}extract{bitop_width}(insn, {self.pos}, {self.len})'
 
+    def referenced_fields(self):
+        return []
+
     def __eq__(self, other):
         return self.sign == other.sign and self.mask == other.mask
 
@@ -321,6 +324,12 @@  def str_extract(self, lvalue_formatter):
             pos += f.len
         return ret
 
+    def referenced_fields(self):
+        l = []
+        for f in self.subs:
+            l.extend(f.referenced_fields())
+        return l
+
     def __ne__(self, other):
         if len(self.subs) != len(other.subs):
             return True
@@ -347,6 +356,9 @@  def __str__(self):
     def str_extract(self, lvalue_formatter):
         return str(self.value)
 
+    def referenced_fields(self):
+        return []
+
     def __cmp__(self, other):
         return self.value - other.value
 # end ConstField
@@ -367,6 +379,9 @@  def str_extract(self, lvalue_formatter):
         return (self.func + '(ctx, '
                 + self.base.str_extract(lvalue_formatter) + ')')
 
+    def referenced_fields(self):
+        return self.base.referenced_fields()
+
     def __eq__(self, other):
         return self.func == other.func and self.base == other.base
 
@@ -388,6 +403,9 @@  def __str__(self):
     def str_extract(self, lvalue_formatter):
         return self.func + '(ctx)'
 
+    def referenced_fields(self):
+        return []
+
     def __eq__(self, other):
         return self.func == other.func
 
@@ -395,6 +413,32 @@  def __ne__(self, other):
         return not self.__eq__(other)
 # end ParameterField
 
+class NamedField:
+    """Class representing a field already named in the pattern"""
+    def __init__(self, name, sign, len):
+        self.mask = 0
+        self.sign = sign
+        self.len = len
+        self.name = name
+
+    def __str__(self):
+        return self.name
+
+    def str_extract(self, lvalue_formatter):
+        global bitop_width
+        s = 's' if self.sign else ''
+        lvalue = lvalue_formatter(self.name)
+        return f'{s}extract{bitop_width}({lvalue}, 0, {self.len})'
+
+    def referenced_fields(self):
+        return [self.name]
+
+    def __eq__(self, other):
+        return self.name == other.name
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+# end NamedField
 
 class Arguments:
     """Class representing the extracted fields of a format"""
@@ -418,7 +462,6 @@  def output_def(self):
             output('} ', self.struct_name(), ';\n\n')
 # end Arguments
 
-
 class General:
     """Common code between instruction formats and instruction patterns"""
     def __init__(self, name, lineno, base, fixb, fixm, udfm, fldm, flds, w):
@@ -432,6 +475,7 @@  def __init__(self, name, lineno, base, fixb, fixm, udfm, fldm, flds, w):
         self.fieldmask = fldm
         self.fields = flds
         self.width = w
+        self.dangling = None
 
     def __str__(self):
         return self.name + ' ' + str_match_bits(self.fixedbits, self.fixedmask)
@@ -439,10 +483,51 @@  def __str__(self):
     def str1(self, i):
         return str_indent(i) + self.__str__()
 
+    def dangling_references(self):
+        # Return a list of all named references which aren't satisfied
+        # directly by this format/pattern. This will be either:
+        #  * a format referring to a field which is specified by the
+        #    pattern(s) using it
+        #  * a pattern referring to a field which is specified by the
+        #    format it uses
+        #  * a user error (referring to a field that doesn't exist at all)
+        if self.dangling is None:
+            # Compute this once and cache the answer
+            dangling = []
+            for n, f in self.fields.items():
+                for r in f.referenced_fields():
+                    if r not in self.fields:
+                        dangling.append(r)
+            self.dangling = dangling
+        return self.dangling
+
     def output_fields(self, indent, lvalue_formatter):
+        # We use a topological sort to ensure that any use of NamedField
+        # comes after the initialization of the field it is referencing.
+        graph = {}
         for n, f in self.fields.items():
-            output(indent, lvalue_formatter(n), ' = ',
-                   f.str_extract(lvalue_formatter), ';\n')
+            refs = f.referenced_fields()
+            graph[n] = refs
+
+        try:
+            ts = TopologicalSorter(graph)
+            for n in ts.static_order():
+                # We only want to emit assignments for the keys
+                # in our fields list, not for anything that ends up
+                # in the tsort graph only because it was referenced as
+                # a NamedField.
+                try:
+                    f = self.fields[n]
+                    output(indent, lvalue_formatter(n), ' = ',
+                           f.str_extract(lvalue_formatter), ';\n')
+                except KeyError:
+                    pass
+        except CycleError as e:
+            # The second element of args is a list of nodes which form
+            # a cycle (there might be others too, but only one is reported).
+            # Pretty-print it to tell the user.
+            cycle = ' => '.join(e.args[1])
+            error(self.lineno, 'field definitions form a cycle: ' + cycle)
 # end General
 
 
@@ -477,10 +562,36 @@  def output_code(self, i, extracted, outerbits, outermask):
         ind = str_indent(i)
         arg = self.base.base.name
         output(ind, '/* ', self.file, ':', str(self.lineno), ' */\n')
+        # We might have named references in the format that refer to fields
+        # in the pattern, or named references in the pattern that refer
+        # to fields in the format. This affects whether we extract the fields
+        # for the format before or after the ones for the pattern.
+        # For simplicity we don't allow cross references in both directions.
+        # This is also where we catch the syntax error of referring to
+        # a nonexistent field.
+        fmt_refs = self.base.dangling_references()
+        for r in fmt_refs:
+            if r not in self.fields:
+                error(self.lineno, f'format refers to undefined field {r}')
+        pat_refs = self.dangling_references()
+        for r in pat_refs:
+            if r not in self.base.fields:
+                error(self.lineno, f'pattern refers to undefined field {r}')
+        if pat_refs and fmt_refs:
+            error(self.lineno, ('pattern that uses fields defined in format '
+                                'cannot use format that uses fields defined '
+                                'in pattern'))
+        if fmt_refs:
+            # pattern fields first
+            self.output_fields(ind, lambda n: 'u.f_' + arg + '.' + n)
+            assert not extracted, "dangling fmt refs but it was already extracted"
         if not extracted:
             output(ind, self.base.extract_name(),
                    '(ctx, &u.f_', arg, ', insn);\n')
-        self.output_fields(ind, lambda n: 'u.f_' + arg + '.' + n)
+        if not fmt_refs:
+            # pattern fields last
+            self.output_fields(ind, lambda n: 'u.f_' + arg + '.' + n)
+
         output(ind, 'if (', translate_prefix, '_', self.name,
                '(ctx, &u.f_', arg, ')) return true;\n')
 
@@ -626,8 +737,10 @@  def output_code(self, i, extracted, outerbits, outermask):
         ind = str_indent(i)
 
         # If we identified all nodes below have the same format,
-        # extract the fields now.
-        if not extracted and self.base:
+        # extract the fields now. But don't do it if the format relies
+        # on named fields from the insn pattern, as those won't have
+        # been initialised at this point.
+        if not extracted and self.base and not self.base.dangling_references():
             output(ind, self.base.extract_name(),
                    '(ctx, &u.f_', self.base.base.name, ', insn);\n')
             extracted = True
@@ -749,6 +862,7 @@  def parse_field(lineno, name, toks):
     """Parse one instruction field from TOKS at LINENO"""
     global fields
     global insnwidth
+    global re_C_ident
 
     # A "simple" field will have only one entry;
     # a "multifield" will have several.
@@ -763,6 +877,25 @@  def parse_field(lineno, name, toks):
             func = func[1]
             continue
 
+        if re.fullmatch(re_C_ident + ':s[0-9]+', t):
+            # Signed named field
+            subtoks = t.split(':')
+            n = subtoks[0]
+            le = int(subtoks[1])
+            f = NamedField(n, True, le)
+            subs.append(f)
+            width += le
+            continue
+        if re.fullmatch(re_C_ident + ':[0-9]+', t):
+            # Unsigned named field
+            subtoks = t.split(':')
+            n = subtoks[0]
+            le = int(subtoks[1])
+            f = NamedField(n, False, le)
+            subs.append(f)
+            width += le
+            continue
+
         if re.fullmatch('[0-9]+:s[0-9]+', t):
             # Signed field extract
             subtoks = t.split(':s')