diff mbox series

[Ada] Fix new CUDA kernel registration scheme

Message ID 20220530083233.GA210555@adacore.com
State New
Headers show
Series [Ada] Fix new CUDA kernel registration scheme | expand

Commit Message

Pierre-Marie de Rodat May 30, 2022, 8:32 a.m. UTC
Removal of the previous kernel registration scheme unearthed mistakes in
the new one, which were:
- The new kernel registration code relied on the binder expansion phase,
  which didn't happen because the registration code was already
  generated by the binder.
- The kernel handle passed to CUDA_Register_Function was the first eight
  bytes of the code of the host-side procedure representing the kernel
  rather than its address.

Tested on x86_64-pc-linux-gnu, committed on trunk

gcc/ada/

	* bindgen.adb (Gen_CUDA_Init): Remove code generating CUDA
	definitions.
	(Gen_CUDA_Defs): New function, generating definitions
	initialized by Gen_CUDA_Init.
	(Gen_Output_File_Ada): Call Gen_CUDA_Defs instead of
	Gen_CUDA_Init.
	(Gen_Adainit): Call Gen_CUDA_Init.
diff mbox series

Patch

diff --git a/gcc/ada/bindgen.adb b/gcc/ada/bindgen.adb
--- a/gcc/ada/bindgen.adb
+++ b/gcc/ada/bindgen.adb
@@ -311,8 +311,11 @@  package body Bindgen is
    procedure Gen_CodePeer_Wrapper;
    --  For CodePeer, generate wrapper which calls user-defined main subprogram
 
+   procedure Gen_CUDA_Defs;
+   --  Generate definitions needed in order to register kernels
+
    procedure Gen_CUDA_Init;
-   --  When CUDA registration code is needed.
+   --  Generate calls needed in order to register kernels
 
    procedure Gen_Elab_Calls (Elab_Order : Unit_Id_Array);
    --  Generate sequence of elaboration calls
@@ -1115,6 +1118,8 @@  package body Bindgen is
          WBI ("");
       end if;
 
+      Gen_CUDA_Init;
+
       Gen_Elab_Calls (Elab_Order);
 
       if not CodePeer_Mode then
@@ -1221,10 +1226,10 @@  package body Bindgen is
    end Gen_Bind_Env_String;
 
    -------------------
-   -- Gen_CUDA_Init --
+   -- Gen_CUDA_Defs --
    -------------------
 
-   procedure Gen_CUDA_Init is
+   procedure Gen_CUDA_Defs is
       Unit_Name : constant String :=
         Get_Name_String (Units.Table (First_Unit_Entry).Uname);
       Unit : constant String :=
@@ -1237,7 +1242,7 @@  package body Bindgen is
       WBI ("");
       WBI ("   ");
 
-      WBI ("   function CUDA_Register_Function");
+      WBI ("   procedure CUDA_Register_Function");
       WBI ("      (Fat_Binary_Handle : System.Address;");
       WBI ("       Func : System.Address;");
       WBI ("       Kernel_Name : Interfaces.C.Strings.chars_ptr;");
@@ -1247,7 +1252,7 @@  package body Bindgen is
       WBI ("       Nullptr2 : System.Address;");
       WBI ("       Nullptr3 : System.Address;");
       WBI ("       Nullptr4 : System.Address;");
-      WBI ("       Nullptr5 : System.Address) return Boolean;");
+      WBI ("       Nullptr5 : System.Address);");
       WBI ("   pragma Import");
       WBI ("     (Convention => C,");
       WBI ("      Entity => CUDA_Register_Function,");
@@ -1261,8 +1266,8 @@  package body Bindgen is
       WBI ("       Entity => CUDA_Register_Fat_Binary,");
       WBI ("       External_Name => ""__cudaRegisterFatBinary"");");
       WBI ("");
-      WBI ("   function CUDA_Register_Fat_Binary_End");
-      WBI ("     (Fat_Binary : System.Address) return Boolean;");
+      WBI ("   procedure CUDA_Register_Fat_Binary_End");
+      WBI ("     (Fat_Binary : System.Address);");
       WBI ("   pragma Import");
       WBI ("     (Convention => C,");
       WBI ("      Entity => CUDA_Register_Fat_Binary_End,");
@@ -1287,8 +1292,7 @@  package body Bindgen is
       WBI ("      Fat_Binary'Address,");
       WBI ("      System.Null_Address);");
       WBI ("");
-      WBI ("   Fat_Binary_Handle : System.Address :=");
-      WBI ("     CUDA_Register_Fat_Binary (Wrapper'Address);");
+      WBI ("   Fat_Binary_Handle : System.Address;");
       WBI ("");
 
       for K in CUDA_Kernels.First .. CUDA_Kernels.Last loop
@@ -1300,9 +1304,9 @@  package body Bindgen is
             --  K_Symbol is a unique identifier used to derive all symbol names
             --  related to kernel K.
 
-            Kernel_Addr : constant String := Kernel_Symbol & "_Addr";
-            --  Kernel_Addr is the name of the symbol representing the address
-            --  of the host-side procedure of the kernel. The address is
+            Kernel_Proc : constant String := Kernel_Symbol & "_Proc";
+            --  Kernel_Proc is the name of the symbol representing the
+            --  host-side procedure of the kernel. The address is
             --  pragma-imported and then used while registering the kernel with
             --  the CUDA runtime.
             Kernel_String : constant String := Kernel_Symbol & "_String";
@@ -1315,40 +1319,80 @@  package body Bindgen is
 
          begin
             --  Import host-side kernel address.
-            WBI ("   " & Kernel_Addr & " : constant System.Address;");
+            WBI ("   procedure " & Kernel_Proc & ";");
             WBI ("   pragma Import");
             WBI ("      (Convention    => C,");
-            WBI ("       Entity        => " & Kernel_Addr & ",");
+            WBI ("       Entity        => " & Kernel_Proc & ",");
             WBI ("       External_Name => """ & Kernel_Name & """);");
             WBI ("");
 
             --  Generate C-string containing name of kernel.
             WBI
-              ("   " & Kernel_String & " : Interfaces.C.Strings.Chars_Ptr :=");
-            WBI ("    Interfaces.C.Strings.New_Char_Array ("""
-                  & Kernel_Name
-                  & """);");
+              ("   " & Kernel_String & " : Interfaces.C.Strings.Chars_Ptr;");
             WBI ("");
 
+         end;
+      end loop;
+
+      WBI ("");
+   end Gen_CUDA_Defs;
+
+   -------------------
+   -- Gen_CUDA_Init --
+   -------------------
+
+   procedure Gen_CUDA_Init is
+   begin
+      if not Enable_CUDA_Expansion then
+         return;
+      end if;
+
+      WBI ("      Fat_Binary_Handle :=");
+      WBI ("        CUDA_Register_Fat_Binary (Wrapper'Address);");
+
+      for K in CUDA_Kernels.First .. CUDA_Kernels.Last loop
+         declare
+            K_String : constant String := CUDA_Kernel_Id'Image (K);
+            N : constant String :=
+              K_String (K_String'First + 1 .. K_String'Last);
+            Kernel_Symbol : constant String := "Kernel_" & N;
+            --  K_Symbol is a unique identifier used to derive all symbol names
+            --  related to kernel K.
+
+            Kernel_Proc : constant String := Kernel_Symbol & "_Proc";
+            --  Kernel_Proc is the name of the symbol representing the
+            --  host-side procedure of the kernel. The address is
+            --  pragma-imported and then used while registering the kernel with
+            --  the CUDA runtime.
+            Kernel_String : constant String := Kernel_Symbol & "_String";
+            --  Kernel_String is the name of the C-string containing the name
+            --  of the kernel. It is used for registering the kernel with the
+            --  CUDA runtime.
+            Kernel_Name : constant String :=
+               Get_Name_String (CUDA_Kernels.Table (K).Kernel_Name);
+            --  Kernel_Name is the name of the kernel, after package expansion.
+         begin
+            WBI ("      " & Kernel_String & " :=");
+            WBI ("        Interfaces.C.Strings.New_Char_Array ("""
+                  & Kernel_Name
+                  & """);");
             --  Generate call to CUDA runtime to register function.
-            WBI ("   CUDA_Register" & N & " : Boolean :=");
-            WBI ("     CUDA_Register_Function (");
-            WBI ("       Fat_Binary_Handle, ");
-            WBI ("       " & Kernel_Addr & ",");
-            WBI ("       " & Kernel_String & ",");
-            WBI ("       " & Kernel_String & ",");
-            WBI ("       -1,");
-            WBI ("       System.Null_Address,");
-            WBI ("       System.Null_Address,");
-            WBI ("       System.Null_Address,");
-            WBI ("       System.Null_Address,");
-            WBI ("       System.Null_Address);");
+            WBI ("      CUDA_Register_Function (");
+            WBI ("        Fat_Binary_Handle, ");
+            WBI ("        " & Kernel_Proc & "'Address,");
+            WBI ("        " & Kernel_String & ",");
+            WBI ("        " & Kernel_String & ",");
+            WBI ("        -1,");
+            WBI ("        System.Null_Address,");
+            WBI ("        System.Null_Address,");
+            WBI ("        System.Null_Address,");
+            WBI ("        System.Null_Address,");
+            WBI ("        System.Null_Address);");
             WBI ("");
          end;
       end loop;
 
-      WBI ("   CUDA_End : Boolean := ");
-      WBI ("      CUDA_Register_Fat_Binary_End(Fat_Binary_Handle);");
+      WBI ("      CUDA_Register_Fat_Binary_End (Fat_Binary_Handle);");
    end Gen_CUDA_Init;
 
    --------------------------
@@ -2619,7 +2663,7 @@  package body Bindgen is
            Get_Main_Name & """);");
       end if;
 
-      Gen_CUDA_Init;
+      Gen_CUDA_Defs;
 
       --  Generate version numbers for units, only if needed. Be very safe on
       --  the condition.