diff mbox

[HSA] Fix emission of hsa_num_threads

Message ID 561D192F.3080307@suse.cz
State New
Headers show

Commit Message

Martin Liška Oct. 13, 2015, 2:46 p.m. UTC
Hello.

The following pair of patches changes the behavior of omp_{get,set}_num_threads
and provides a cleverer way of passing these values to another kernel.

Martin
diff mbox

Patch

From 7f10daa1f37ee47091a3956a13bb610464e8e279 Mon Sep 17 00:00:00 2001
From: marxin <mliska@suse.cz>
Date: Mon, 12 Oct 2015 15:49:50 +0200
Subject: [PATCH 2/2] HSA: handle properly number of threads in a kernel

gcc/ChangeLog:

2015-10-13  Martin Liska  <mliska@suse.cz>

	* hsa-gen.c (hsa_insn_basic::set_output_in_type): New function.
	(query_hsa_grid): Likewise.
	(gen_set_num_threads): Save the value without any value range
	checking.
	(gen_num_threads_for_dispatch): New function.
	(gen_hsa_insns_for_known_library_call): Use the newly added
	function query_hsa_grid.
	(gen_hsa_insns_for_call): Likewise.
	(gen_hsa_insns_for_kernel_call): Use the newly added function
	gen_num_threads_for_dispatch.
	(init_omp_in_prologue): Initialize hsa_num_threads to 0.
	(init_prologue): New function.
	(init_hsa_num_threads): Likewise.
	* hsa.h: Declare a new function.
---
 gcc/hsa-gen.c | 224 ++++++++++++++++++++++++++++++++++++----------------------
 gcc/hsa.h     |   1 +
 2 files changed, 141 insertions(+), 84 deletions(-)

diff --git a/gcc/hsa-gen.c b/gcc/hsa-gen.c
index ab4917b..e64f4c6 100644
--- a/gcc/hsa-gen.c
+++ b/gcc/hsa-gen.c
@@ -105,6 +105,10 @@  along with GCC; see the file COPYING3.  If not see
   } \
   while (false);
 
+/* Default number of threads used by kernel dispatch.  */
+
+#define HSA_DEFAULT_NUM_THREADS 64
+
 /* Following structures are defined in the final version
    of HSA specification.  */
 
@@ -3238,27 +3242,67 @@  gen_hsa_insns_for_return (greturn *stmt, hsa_bb *hbb,
   hbb->append_insn (ret);
 }
 
-/* Emit instructions that assign number of threads to lhs of gimple STMT.
- Intructions are appended to basic block HBB and SSA_MAP maps gimple
- SSA names to HSA pseudo registers.  */
+/* Set OP_INDEX-th operand of the instruction to DEST.  As DEST can have a
+   different type, conversion instructions are possibly appended to
+   HBB.  */
 
-static void
-gen_get_num_threads (gimple *stmt, hsa_bb *hbb, vec <hsa_op_reg_p> *ssa_map)
+void
+hsa_insn_basic::set_output_in_type (hsa_op_reg *dest, unsigned op_index,
+				    hsa_bb *hbb)
 {
-  if (gimple_call_lhs (stmt) == NULL_TREE)
-    return;
+  hsa_insn_basic *insn;
+  gcc_checking_assert (hsa_opcode_op_output_p (opcode, op_index));
 
-  hbb->append_insn (new hsa_insn_comment ("omp_get_num_threads"));
-  hsa_op_address *addr = new hsa_op_address (hsa_num_threads);
+  if (dest->type == type)
+    set_op (op_index, dest);
 
-  hsa_op_reg *dest = hsa_reg_for_gimple_ssa (gimple_call_lhs (stmt),
-					     ssa_map);
-  hsa_insn_basic *basic = new hsa_insn_mem
-    (BRIG_OPCODE_LD, dest->type, dest, addr);
+  hsa_op_reg *tmp = new hsa_op_reg (type);
+  set_op (op_index, tmp);
 
-  hbb->append_insn (basic);
+  if (hsa_needs_cvt (dest->type, type))
+    insn = new hsa_insn_basic (2, BRIG_OPCODE_CVT, dest->type,
+			       dest, tmp);
+  else
+    insn = new hsa_insn_basic (2, BRIG_OPCODE_MOV, dest->type,
+			       dest, tmp->get_in_type (dest->type, hbb));
+
+  hbb->append_insn (insn);
 }
 
+/* Generate instruction OPCODE to query a property of HSA grid along the
+   given DIMENSION.  Store result into DEST and append the instruction to
+   HBB.  */
+
+static void
+query_hsa_grid (hsa_op_reg *dest, BrigType16_t opcode, int dimension,
+		hsa_bb *hbb)
+{
+  /* We're using just one-dimensional kernels, so hard-coded
+     dimension X.  */
+  hsa_op_immed *imm = new hsa_op_immed (dimension,
+					(BrigKind16_t) BRIG_TYPE_U32);
+  hsa_insn_basic *insn = new hsa_insn_basic (2, opcode, BRIG_TYPE_U32, NULL,
+					     imm);
+  hbb->append_insn (insn);
+  insn->set_output_in_type (dest, 0, hbb);
+}
+
+/* Generate a special HSA-related instruction for gimple STMT.
+   Instructions are appended to basic block HBB and SSA_MAP maps gimple
+   SSA names to HSA pseudo registers.  */
+
+static void
+query_hsa_grid (gimple *stmt, BrigOpcode16_t opcode, int dimension,
+		hsa_bb *hbb, vec <hsa_op_reg_p> *ssa_map)
+{
+  tree lhs = gimple_call_lhs (dyn_cast <gcall *> (stmt));
+  if (lhs == NULL_TREE)
+    return;
+
+  hsa_op_reg *dest = hsa_reg_for_gimple_ssa (lhs, ssa_map);
+
+  query_hsa_grid (dest, opcode, dimension, hbb);
+}
 
 /* Emit instructions that set hsa_num_threads according to provided VALUE.
  Intructions are appended to basic block HBB and SSA_MAP maps gimple
@@ -3268,30 +3312,71 @@  static void
 gen_set_num_threads (tree value, hsa_bb *hbb, vec <hsa_op_reg_p> *ssa_map)
 {
   hbb->append_insn (new hsa_insn_comment ("omp_set_num_threads"));
-  hsa_op_with_type *src = hsa_reg_or_immed_for_gimple_op (value, hbb,
-							  ssa_map);
+  hsa_op_with_type *src = hsa_reg_or_immed_for_gimple_op (value, hbb, ssa_map);
 
   src = src->get_in_type (hsa_num_threads->type, hbb);
   hsa_op_address *addr = new hsa_op_address (hsa_num_threads);
 
-  hsa_op_immed *limit = new hsa_op_immed (64, BRIG_TYPE_U32);
+  hsa_insn_basic *basic = new hsa_insn_mem
+    (BRIG_OPCODE_ST, hsa_num_threads->type, src, addr);
+  hbb->append_insn (basic);
+}
+
+/* Return an HSA register that will contain number of threads for
+   a future dispatched kernel.  Instructions are added to HBB.  */
+
+static hsa_op_reg *
+gen_num_threads_for_dispatch (hsa_bb *hbb)
+{
+  /* Step 1) Assign to number of threads:
+     MIN (HSA_DEFAULT_NUM_THREADS, hsa_num_threads).  */
+  hsa_op_reg *threads = new hsa_op_reg (hsa_num_threads->type);
+  hsa_op_address *addr = new hsa_op_address (hsa_num_threads);
+
+  hbb->append_insn (new hsa_insn_mem (BRIG_OPCODE_LD, threads->type,
+				      threads, addr));
+
+  hsa_op_immed *limit = new hsa_op_immed (HSA_DEFAULT_NUM_THREADS,
+					  BRIG_TYPE_U32);
   hsa_op_reg *r = new hsa_op_reg (BRIG_TYPE_B1);
   hbb->append_insn
-    (new hsa_insn_cmp (BRIG_COMPARE_LT, r->type, r, src, limit));
+    (new hsa_insn_cmp (BRIG_COMPARE_LT, r->type, r, threads, limit));
 
-  BrigType16_t btype = hsa_bittype_for_type (hsa_num_threads->type);
-  hsa_op_reg *src_min_reg = new hsa_op_reg (btype);
+  BrigType16_t btype = hsa_bittype_for_type (threads->type);
+  hsa_op_reg *tmp = new hsa_op_reg (threads->type);
 
   hbb->append_insn
-    (new hsa_insn_basic (4, BRIG_OPCODE_CMOV, src_min_reg->type,
-			 src_min_reg, r, src, limit));
+    (new hsa_insn_basic (4, BRIG_OPCODE_CMOV, btype, tmp, r,
+			 threads, limit));
 
-  hsa_insn_basic *basic = new hsa_insn_mem
-    (BRIG_OPCODE_ST, hsa_num_threads->type, src_min_reg, addr);
+  /* Step 2) If the number is equal to zero,
+     return shadow->omp_num_threads.  */
+  hsa_op_reg *shadow_reg_ptr = hsa_cfun->get_shadow_reg ();
 
+  hsa_op_reg *shadow_thread_count = new hsa_op_reg (BRIG_TYPE_U32);
+  addr = new hsa_op_address
+   (shadow_reg_ptr, offsetof (hsa_kernel_dispatch, omp_num_threads));
+  hsa_insn_basic *basic = new hsa_insn_mem
+   (BRIG_OPCODE_LD, shadow_thread_count->type, shadow_thread_count, addr);
   hbb->append_insn (basic);
+
+  hsa_op_reg *tmp2 = new hsa_op_reg (threads->type);
+  r = new hsa_op_reg (BRIG_TYPE_B1);
+  hbb->append_insn
+    (new hsa_insn_cmp (BRIG_COMPARE_EQ, r->type, r, tmp,
+		       new hsa_op_immed (0, shadow_thread_count->type)));
+  hbb->append_insn
+    (new hsa_insn_basic (4, BRIG_OPCODE_CMOV, btype, tmp2, r,
+			 shadow_thread_count, tmp));
+
+  hsa_op_reg *dest = new hsa_op_reg (BRIG_TYPE_U16);
+  hbb->append_insn (new hsa_insn_basic (2, BRIG_OPCODE_CVT, dest->type,
+					dest, tmp2));
+
+  return dest;
 }
 
+
 /* Emit instructions that assign number of teams to lhs of gimple STMT.
    Intructions are appended to basic block HBB and SSA_MAP maps gimple
    SSA names to HSA pseudo registers.  */
@@ -3381,7 +3466,7 @@  gen_hsa_insns_for_known_library_call (gimple *stmt, hsa_bb *hbb,
     }
   else if (strcmp (name, "omp_get_num_threads") == 0)
     {
-      gen_get_num_threads (stmt, hbb, ssa_map);
+      query_hsa_grid (stmt, BRIG_OPCODE_GRIDSIZE, 0, hbb, ssa_map);
       return true;
     }
   else if (strcmp (name, "omp_get_num_teams") == 0)
@@ -3606,24 +3691,17 @@  gen_hsa_insns_for_kernel_call (hsa_bb *hbb, gcall *call)
 			  addr);
   hbb->append_insn (mem);
 
-  /* Write to packet->grid_size_x.  */
+  /* Write to packet->grid_size_x.  If the default value is not changed,
+     emit passed grid_size.  */
+  hsa_op_reg *threads_reg = gen_num_threads_for_dispatch (hbb);
+
   hbb->append_insn (new hsa_insn_comment
 		    ("set packet->grid_size_x = hsa_num_threads"));
 
   addr = new hsa_op_address (queue_packet_reg,
 			     offsetof (hsa_queue_packet, grid_size_x));
 
-  hsa_op_reg *hsa_num_threads_reg = new hsa_op_reg (hsa_num_threads->type);
-  hbb->append_insn (new hsa_insn_mem (BRIG_OPCODE_LD, hsa_num_threads->type,
-				      hsa_num_threads_reg,
-				      new hsa_op_address (hsa_num_threads)));
-
-  hsa_op_reg *threads_u16_reg = new hsa_op_reg (BRIG_TYPE_U16);
-  hbb->append_insn (new hsa_insn_basic (2, BRIG_OPCODE_CVT, BRIG_TYPE_U16,
-					threads_u16_reg, hsa_num_threads_reg));
-
-  mem = new hsa_insn_mem (BRIG_OPCODE_ST, BRIG_TYPE_U16, threads_u16_reg,
-			  addr);
+  mem = new hsa_insn_mem (BRIG_OPCODE_ST, BRIG_TYPE_U16, threads_reg, addr);
   hbb->append_insn (mem);
 
   /* Write to shadow_reg->omp_num_threads = hsa_num_threads.  */
@@ -3633,8 +3711,7 @@  gen_hsa_insns_for_kernel_call (hsa_bb *hbb, gcall *call)
   addr = new hsa_op_address (shadow_reg, offsetof (hsa_kernel_dispatch,
 						   omp_num_threads));
   hbb->append_insn
-    (new hsa_insn_mem (BRIG_OPCODE_ST, hsa_num_threads_reg->type,
-		       hsa_num_threads_reg, addr));
+    (new hsa_insn_mem (BRIG_OPCODE_ST, threads_reg->type, threads_reg, addr));
 
   /* Write to packet->workgroup_size_x.  */
   hbb->append_insn (new hsa_insn_comment
@@ -3642,7 +3719,7 @@  gen_hsa_insns_for_kernel_call (hsa_bb *hbb, gcall *call)
 
   addr = new hsa_op_address (queue_packet_reg,
 			     offsetof (hsa_queue_packet, workgroup_size_x));
-  mem = new hsa_insn_mem (BRIG_OPCODE_ST, BRIG_TYPE_U16, threads_u16_reg,
+  mem = new hsa_insn_mem (BRIG_OPCODE_ST, BRIG_TYPE_U16, threads_reg,
 			  addr);
   hbb->append_insn (mem);
 
@@ -4024,8 +4101,6 @@  gen_hsa_insns_for_call (gimple *stmt, hsa_bb *hbb,
 {
   tree lhs = gimple_call_lhs (stmt);
   hsa_op_reg *dest;
-  hsa_insn_basic *insn;
-  int opcode;
 
   if (!gimple_call_builtin_p (stmt, BUILT_IN_NORMAL))
     {
@@ -4050,36 +4125,14 @@  gen_hsa_insns_for_call (gimple *stmt, hsa_bb *hbb,
   switch (DECL_FUNCTION_CODE (fndecl))
     {
     case BUILT_IN_OMP_GET_THREAD_NUM:
-      opcode = BRIG_OPCODE_WORKITEMABSID;
-      goto specialop;
-
-    case BUILT_IN_OMP_GET_NUM_THREADS:
       {
-	gen_get_num_threads (stmt, hbb, ssa_map);
+	query_hsa_grid (stmt, BRIG_OPCODE_WORKITEMABSID, 0, hbb, ssa_map);
 	break;
       }
 
-specialop:
+    case BUILT_IN_OMP_GET_NUM_THREADS:
       {
-	hsa_op_reg *tmp;
-	dest = hsa_reg_for_gimple_ssa (lhs, ssa_map);
-	/* We're using just one-dimensional kernels, so hard-coded
-	   dimension X.  */
-	hsa_op_immed *imm = new hsa_op_immed
-	  (build_zero_cst (uint32_type_node));
-	if (dest->type != BRIG_TYPE_U32)
-	  tmp = new hsa_op_reg (BRIG_TYPE_U32);
-	else
-	  tmp = dest;
-	insn = new hsa_insn_basic (2, opcode, tmp->type, tmp, imm);
-	hbb->append_insn (insn);
-	if (dest != tmp)
-	  {
-	    int opc2 = dest->type == BRIG_TYPE_S32 ? BRIG_OPCODE_MOV
-	      : BRIG_OPCODE_CVT;
-	    insn = new hsa_insn_basic (2, opc2, dest->type, dest, tmp);
-	    hbb->append_insn (insn);
-	  }
+	query_hsa_grid (stmt, BRIG_OPCODE_GRIDSIZE, 0, hbb, ssa_map);
 	break;
       }
 
@@ -4618,28 +4671,13 @@  hsa_init_new_bb (basic_block bb)
 /* Initialize OMP in an HSA basic block PROLOGUE.  */
 
 static void
-init_omp_in_prologue (void)
+init_prologue (void)
 {
   if (!hsa_cfun->kern_p)
     return;
 
   hsa_bb *prologue = hsa_bb_for_bb (ENTRY_BLOCK_PTR_FOR_FN (cfun));
 
-  /* Load a default value from shadow argument.  */
-  hsa_op_reg *shadow_reg_ptr = hsa_cfun->get_shadow_reg ();
-  hsa_op_address *addr = new hsa_op_address
-    (shadow_reg_ptr, offsetof (hsa_kernel_dispatch, omp_num_threads));
-
-  hsa_op_reg *threads = new hsa_op_reg (BRIG_TYPE_U32);
-  hsa_insn_basic *basic = new hsa_insn_mem
-    (BRIG_OPCODE_LD, threads->type, threads, addr);
-  prologue->append_insn (basic);
-
-  /* Save it to private variable hsa_num_threads.  */
-  basic = new hsa_insn_mem (BRIG_OPCODE_ST, hsa_num_threads->type, threads,
-			    new hsa_op_address (hsa_num_threads));
-  prologue->append_insn (basic);
-
   /* Create a magic number that is going to be printed by libgomp.  */
   unsigned index = hsa_get_number_decl_kernel_mappings ();
 
@@ -4648,6 +4686,21 @@  init_omp_in_prologue (void)
     set_debug_value (prologue, new hsa_op_immed (1000 + index, BRIG_TYPE_U64));
 }
 
+/* Initialize hsa_num_threads to a default value.  */
+
+static void
+init_hsa_num_threads (void)
+{
+  hsa_bb *prologue = hsa_bb_for_bb (ENTRY_BLOCK_PTR_FOR_FN (cfun));
+
+  /* Save the default value to private variable hsa_num_threads.  */
+  hsa_insn_basic *basic = new hsa_insn_mem
+    (BRIG_OPCODE_ST, hsa_num_threads->type,
+     new hsa_op_immed (0, hsa_num_threads->type),
+     new hsa_op_address (hsa_num_threads));
+  prologue->append_insn (basic);
+}
+
 /* Go over gimple representation and generate our internal HSA one.  SSA_MAP
    maps gimple SSA names to HSA pseudo registers.  */
 
@@ -5150,12 +5203,15 @@  generate_hsa (bool kernel)
   if (hsa_seen_error ())
     goto fail;
 
-  init_omp_in_prologue ();
+  init_prologue ();
 
   gen_body_from_gimple (&ssa_map);
   if (hsa_seen_error ())
     goto fail;
 
+  if (hsa_cfun->kernel_dispatch_count)
+    init_hsa_num_threads ();
+
   if (hsa_cfun->kern_p)
     {
       hsa_add_kern_decl_mapping (current_function_decl, hsa_cfun->name,
diff --git a/gcc/hsa.h b/gcc/hsa.h
index 89d339f..c7e3957 100644
--- a/gcc/hsa.h
+++ b/gcc/hsa.h
@@ -364,6 +364,7 @@  public:
   void verify ();
   unsigned input_count ();
   unsigned num_used_ops ();
+  void set_output_in_type (hsa_op_reg *dest, unsigned op_index, hsa_bb *hbb);
 
   /* The previous and next instruction in the basic block.  */
   hsa_insn_basic *prev, *next;
-- 
2.6.0