diff mbox series

[2/2] jit: function pointers: bring the C++ API more on par with C

Message ID 20220131163440.8380-2-marc@nieper-wisskirchen.de
State New
Headers show
Series [1/2] jit: structs/unions: bring the C++ API more on par with C | expand

Commit Message

Marc Nieper-Wißkirchen Jan. 31, 2022, 4:34 p.m. UTC
This patch adds yet missing support for creating function pointer
types and to call functions through functions pointers in the C++ API
of libgccjit.

The method gccjit::context::new_call is overloaded so that it accepts
an rvalue instead of a function argument.  Instead of creating ad-hoc
specializations for 0 to 6 arguments, another overload taking an
initializer list instead of a vector is added.

gcc/jit/ChangeLog:

	* docs/cp/topics/expressions.rst: Updated.
	* docs/cp/topics/types.rst: Updated.
	* libgccjit++.h (context::new_function_ptr_type):
          New method.
	  (context::new_call): New method overloads.

gcc/testsuite/ChangeLog:

	* jit.dg/test-calling-function-ptr.cc: New test.
---
 gcc/jit/docs/cp/topics/expressions.rst        | 34 +++++--
 gcc/jit/docs/cp/topics/types.rst              | 16 ++++
 gcc/jit/libgccjit++.h                         | 70 ++++++++++++++
 .../jit.dg/test-calling-function-ptr.cc       | 93 +++++++++++++++++++
 4 files changed, 204 insertions(+), 9 deletions(-)
 create mode 100644 gcc/testsuite/jit.dg/test-calling-function-ptr.cc
diff mbox series

Patch

diff --git a/gcc/jit/docs/cp/topics/expressions.rst b/gcc/jit/docs/cp/topics/expressions.rst
index 3e9534790a3..d6c8f0de156 100644
--- a/gcc/jit/docs/cp/topics/expressions.rst
+++ b/gcc/jit/docs/cp/topics/expressions.rst
@@ -119,8 +119,8 @@  Vector expressions
 ******************
 
 .. function:: gccjit::rvalue \
-	      gccjit::context::new_rvalue (gccjit::type vector_type, \
-	                                   std::vector<gccjit::rvalue> elements) const
+              gccjit::context::new_rvalue (gccjit::type vector_type, \
+                                           std::vector<gccjit::rvalue> elements) const
 
    Given a vector type, and a vector of scalar rvalue elements, generate a
    vector rvalue.
@@ -454,13 +454,13 @@  The most concise way to spell them is with overloaded operators:
 
 Function calls
 **************
-.. function:: gcc_jit_rvalue *\
-              gcc_jit_context_new_call (gcc_jit_context *ctxt,\
-                                        gcc_jit_location *loc,\
-                                        gcc_jit_function *func,\
-                                        int numargs , gcc_jit_rvalue **args)
 
-   Given a function and the given table of argument rvalues, construct a
+.. function:: gccjit::rvalue \
+              gccjit::context::new_call (gccjit::function func, \
+                                         std::vector <gccjit::rvalue> &args, \
+                                         gccjit::location loc)
+
+   Given a function and the given vector of argument rvalues, construct a
    call to the function, with the result as an rvalue.
 
    .. note::
@@ -480,11 +480,27 @@  Function calls
          /* Add "(void)printf (arg0, arg1);".  */
          block.add_eval (ctxt.new_call (printf_func, arg0, arg1));
 
+.. function:: gccjit::rvalue \
+              gccjit::context::new_call (gccjit::rvalue fn_ptr, \
+                                         std::vector <gccjit::rvalue> &args, \
+                                         gccjit::location loc)
+
+   Given a function and the given vector of argument rvalues, construct a
+   call to the function, with the result as an rvalue.
+
+.. function:: gccjit::rvalue \
+   gccjit::context::new_call (gccjit::rvalue fn_ptr, \
+                              std::initializer_list <gccjit::rvalue> args, \
+                              gccjit::location loc)
+
+   Given a function and the given list of argument rvalues, construct a
+   call to the function, with the result as an rvalue.
+
 Function pointers
 *****************
 
 .. function:: gccjit::rvalue \
-	      gccjit::function::get_address (gccjit::location loc)
+              gccjit::function::get_address (gccjit::location loc)
 
    Get the address of a function as an rvalue, of function pointer
    type.
diff --git a/gcc/jit/docs/cp/topics/types.rst b/gcc/jit/docs/cp/topics/types.rst
index f41f504da6a..8c677d74834 100644
--- a/gcc/jit/docs/cp/topics/types.rst
+++ b/gcc/jit/docs/cp/topics/types.rst
@@ -216,3 +216,19 @@  You can model C `struct` types by creating :class:`gccjit::struct_` and
 					       gccjit::location loc)
 
    Construct a new union type, with the given name and fields.
+
+
+Function pointer types
+----------------------
+
+.. function::  gccjit::type \
+   gccjit::context::new_function_ptr_type (gccjit::type return_type, \
+                                           std::vector <type> &param_types, \
+                                           int is_variadic, \
+                                           gccjit::location loc)
+
+   Generate a :class:`type` for a function pointer with the given
+   return type and parameters.
+
+   Each of `param_types` must be non-`void`; `return_type` may be
+   `void`.
diff --git a/gcc/jit/libgccjit++.h b/gcc/jit/libgccjit++.h
index 17b10bc55c3..fa06c9cc4e1 100644
--- a/gcc/jit/libgccjit++.h
+++ b/gcc/jit/libgccjit++.h
@@ -169,6 +169,11 @@  namespace gccjit
 			 std::vector <field> &fields,
 			 location loc = location ());
 
+    type new_function_ptr_type (type return_type,
+				std::vector <type> &param_types,
+				int is_variadic,
+				location loc = location ());
+
     param new_param (type type_,
 		     const std::string &name,
 		     location loc = location ());
@@ -324,6 +329,11 @@  namespace gccjit
 		     rvalue arg3, rvalue arg4, rvalue arg5,
 		     location loc = location ());
 
+    rvalue new_call (rvalue fn_ptr, std::vector <rvalue> &args,
+		     location loc = location ());
+    rvalue new_call (rvalue fn_ptr, std::initializer_list <rvalue> args,
+		     location loc = location ());
+
     rvalue new_cast (rvalue expr,
 		     type type_,
 		     location loc = location ());
@@ -893,6 +903,29 @@  context::new_union_type (const std::string &name,
 					       as_array_of_ptrs));
 }
 
+inline type
+context::new_function_ptr_type (type return_type,
+				std::vector <type> &param_types,
+				int is_variadic,
+				location loc)
+{
+  /* Treat std::vector as an array, relying on it not being resized: */
+  type *as_array_of_wrappers = &param_types[0];
+
+  /* Treat the array as being of the underlying pointers, relying on
+     the wrapper type being such a pointer internally.	*/
+  gcc_jit_type **as_array_of_ptrs =
+    reinterpret_cast<gcc_jit_type **> (as_array_of_wrappers);
+
+  return type
+    (gcc_jit_context_new_function_ptr_type (m_inner_ctxt,
+					    loc.get_inner_location (),
+					    return_type.get_inner_type (),
+					    param_types.size (),
+					    as_array_of_ptrs,
+					    is_variadic));
+}
+
 inline param
 context::new_param (type type_,
 		    const std::string &name,
@@ -1241,6 +1274,7 @@  context::new_call (function func,
 				   args.size (),
 				   as_array_of_ptrs);
 }
+
 inline rvalue
 context::new_call (function func,
 		   location loc)
@@ -1322,6 +1356,42 @@  context::new_call (function func,
   return new_call (func, args, loc);
 }
 
+inline rvalue
+context::new_call (rvalue fn_ptr, std::vector <rvalue> &args,
+		   location loc)
+{
+  /* Treat std::vector as an array, relying on it not being resized: */
+  rvalue *as_array_of_wrappers = &args[0];
+
+  /* Treat the array as being of the underlying pointers, relying on
+     the wrapper type being such a pointer internally.	*/
+  gcc_jit_rvalue **as_array_of_ptrs =
+    reinterpret_cast<gcc_jit_rvalue **> (as_array_of_wrappers);
+  return gcc_jit_context_new_call_through_ptr (m_inner_ctxt,
+					       loc.get_inner_location (),
+					       fn_ptr.get_inner_rvalue (),
+					       args.size (),
+					       as_array_of_ptrs);
+}
+
+inline rvalue
+context::new_call (rvalue fn_ptr, std::initializer_list <rvalue> args,
+		   location loc)
+{
+  /* We rely on the underlying C API not modifying the values. */
+  rvalue *as_array_of_wrappers = const_cast <rvalue *> (args.begin ());
+
+  /* Treat the array as being of the underlying pointers, relying on
+     the wrapper type being such a pointer internally.	*/
+  gcc_jit_rvalue **as_array_of_ptrs =
+    reinterpret_cast<gcc_jit_rvalue **> (as_array_of_wrappers);
+  return gcc_jit_context_new_call_through_ptr (m_inner_ctxt,
+					       loc.get_inner_location (),
+					       fn_ptr.get_inner_rvalue (),
+					       args.size (),
+					       as_array_of_ptrs);
+}
+
 inline rvalue
 context::new_cast (rvalue expr,
 		   type type_,
diff --git a/gcc/testsuite/jit.dg/test-calling-function-ptr.cc b/gcc/testsuite/jit.dg/test-calling-function-ptr.cc
new file mode 100644
index 00000000000..f83b08b8869
--- /dev/null
+++ b/gcc/testsuite/jit.dg/test-calling-function-ptr.cc
@@ -0,0 +1,93 @@ 
+#include <stdlib.h>
+#include <stdio.h>
+
+#include "libgccjit++.h"
+
+#include "harness.h"
+
+void
+create_code (gcc_jit_context *c_ctxt, void *user_data)
+{
+  /* Let's try to inject the equivalent of:
+
+     void
+     test_calling_function_ptr (void (*fn_ptr) (int, int, int) fn_ptr,
+                                int a)
+     {
+        fn_ptr (a * 3, a * 4, a * 5);
+     }
+  */
+
+  gccjit::context ctxt (c_ctxt);
+
+  int i;
+  gccjit::type void_type =
+    ctxt.get_type (GCC_JIT_TYPE_VOID);
+  gccjit::type int_type =
+    ctxt.get_type (GCC_JIT_TYPE_INT);
+
+  /* Build the function ptr type.  */
+  std::vector param_types {
+    int_type,
+    int_type,
+    int_type
+  };
+  gccjit::type fn_ptr_type =
+    ctxt.new_function_ptr_type (void_type, param_types, 0);
+
+  /* Build the test_fn.  */
+  gccjit::param param_fn_ptr =
+    ctxt.new_param (fn_ptr_type, "fn_ptr");
+  gccjit::param param_a =
+    ctxt.new_param (int_type, "a");
+
+  std::vector params {param_fn_ptr, param_a};
+  gccjit::function test_fn =
+    ctxt.new_function (GCC_JIT_FUNCTION_EXPORTED,
+                       void_type,
+                       "test_calling_function_ptr",
+                       params,
+                       0);
+  /* "a * 3, a * 4, a * 5"  */
+  gccjit::rvalue args[3];
+  for (i = 0; i < 3; i++)
+    args[i] = param_a * ctxt.new_rvalue (int_type, i + 3);
+  gccjit::block block = test_fn.new_block ();
+  block.add_eval (ctxt.new_call (param_fn_ptr,
+                                 {args[0], args[1], args[2]}));
+  block.end_with_return ();
+}
+
+static int called_through_ptr_with[3];
+
+static void
+function_called_through_fn_ptr (int i, int j, int k)
+{
+  called_through_ptr_with[0] = i;
+  called_through_ptr_with[1] = j;
+  called_through_ptr_with[2] = k;
+}
+
+void
+verify_code (gcc_jit_context *ctxt, gcc_jit_result *result)
+{
+  typedef void (*fn_type) (void (*fn_ptr) (int, int, int),
+			   int);
+  CHECK_NON_NULL (result);
+
+  fn_type test_caller =
+    (fn_type)gcc_jit_result_get_code (result, "test_calling_function_ptr");
+  CHECK_NON_NULL (test_caller);
+
+  called_through_ptr_with[0] = 0;
+  called_through_ptr_with[1] = 0;
+  called_through_ptr_with[2] = 0;
+
+  /* Call the JIT-generated function.  */
+  test_caller (function_called_through_fn_ptr, 5);
+
+  /* Verify that it correctly called "function_called_through_fn_ptr".  */
+  CHECK_VALUE (called_through_ptr_with[0], 15);
+  CHECK_VALUE (called_through_ptr_with[1], 20);
+  CHECK_VALUE (called_through_ptr_with[2], 25);
+}