diff mbox

[HSA] introduce hsa_num_threads

Message ID 560558A4.2030007@suse.cz
State New
Headers show

Commit Message

Martin Liška Sept. 25, 2015, 2:22 p.m. UTC
Hello.

In the following patch HSA is capable of handling various OMP builtins
that are utilized to set or get the number of threads.

Martin
diff mbox

Patch

From adfd806108dc5f9343811171de62b3af1d4ef903 Mon Sep 17 00:00:00 2001
From: marxin <mliska@suse.cz>
Date: Thu, 24 Sep 2015 23:07:14 +0200
Subject: [PATCH] HSA: introduce hsa_num_threads.

gcc/ChangeLog:

2015-09-25  Martin Liska  <mliska@suse.cz>

	* hsa-brig.c (emit_directive_variable): Add support
	for global scope.
	(hsa_brig_emit_omp_symbols): New function.
	* hsa-gen.c (hsa_get_string_cst_symbol): Use the newly added
	global scope flag.
	(gen_get_num_threads): Likewise
	(gen_set_num_threads): Likewise
	(gen_get_num_teams): Likewise
	(gen_get_team_num): Likewise
	(gen_hsa_insns_for_known_library_call): Add new OMP functions.
	(gen_hsa_insns_for_kernel_call): Set grid_size_x and
	workgroup_size_x to hsa_num_threads.
	(gen_hsa_insns_for_call): Handle new OMP builtins.
	(init_omp_in_prologue): New function.
	(gen_body_from_gimple): Emit OMP prologue.
	(emit_hsa_module_variables): New function.
	(generate_hsa): Emit module variables.
	* hsa.c (hsa_num_threads): New global variable.
	* hsa.h (struct hsa_symbol): Declare the variable.
---
 gcc/hsa-brig.c |  11 ++-
 gcc/hsa-gen.c  | 216 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---
 gcc/hsa.c      |   3 +
 gcc/hsa.h      |  14 +++-
 4 files changed, 230 insertions(+), 14 deletions(-)

diff --git a/gcc/hsa-brig.c b/gcc/hsa-brig.c
index 36911be..654132d 100644
--- a/gcc/hsa-brig.c
+++ b/gcc/hsa-brig.c
@@ -567,7 +567,7 @@  emit_directive_variable (struct hsa_symbol *symbol)
 		     "won't work", symbol->decl);
 	}
     }
-  else if (symbol->cst_value)
+  else if (symbol->global_scope_p)
     prefix = '&';
   else
     prefix = '%';
@@ -1923,6 +1923,15 @@  hsa_brig_emit_function (void)
   emit_queued_operands ();
 }
 
+/* Emit all OMP symbols related to OMP.  */
+
+void
+hsa_brig_emit_omp_symbols (void)
+{
+  brig_init ();
+  emit_directive_variable (hsa_num_threads);
+}
+
 /* Unit constructor and destructor statements.  */
 
 static GTY(()) tree hsa_ctor_statements;
diff --git a/gcc/hsa-gen.c b/gcc/hsa-gen.c
index 966989c..6f45bfe 100644
--- a/gcc/hsa-gen.c
+++ b/gcc/hsa-gen.c
@@ -733,6 +733,7 @@  hsa_get_string_cst_symbol (tree string_cst)
   sym->type = sym->cst_value->type;
   sym->dim = TREE_STRING_LENGTH (string_cst);
   sym->name_number = hsa_cfun->readonly_variables.length ();
+  sym->global_scope_p = true;
 
   hsa_cfun->readonly_variables.safe_push (sym);
   hsa_cfun->string_constants_map.put (string_cst, sym);
@@ -1258,8 +1259,10 @@  hsa_insn_sbr::replace_all_labels (basic_block old_bb, basic_block new_bb)
 /* Constructor of comparison instructin.  CMP is the comparison operation and T
    is the result type.  */
 
-hsa_insn_cmp::hsa_insn_cmp (BrigCompareOperation8_t cmp, BrigType16_t t)
-  : hsa_insn_basic (3 , BRIG_OPCODE_CMP, t)
+hsa_insn_cmp::hsa_insn_cmp (BrigCompareOperation8_t cmp, BrigType16_t t,
+			    hsa_op_base *arg0, hsa_op_base *arg1,
+			    hsa_op_base *arg2)
+  : hsa_insn_basic (3 , BRIG_OPCODE_CMP, t, arg0, arg1, arg2)
 {
   compare = cmp;
 }
@@ -3144,6 +3147,116 @@  gen_hsa_insns_for_return (greturn *stmt, hsa_bb *hbb,
   hbb->append_insn (ret);
 }
 
+/* Emit instructions that assign number of threads to lhs of gimple STMT.
+ Intructions are appended to basic block HBB and SSA_MAP maps gimple
+ SSA names to HSA pseudo registers.  */
+
+static void
+gen_get_num_threads (gimple *stmt, hsa_bb *hbb, vec <hsa_op_reg_p> *ssa_map)
+{
+  if (gimple_call_lhs (stmt) == NULL_TREE)
+    return;
+
+  hbb->append_insn (new hsa_insn_comment ("omp_get_num_threads"));
+  hsa_op_address *addr = new hsa_op_address (hsa_num_threads);
+
+  hsa_op_reg *dest = hsa_reg_for_gimple_ssa (gimple_call_lhs (stmt),
+					     ssa_map);
+  hsa_insn_basic *basic = new hsa_insn_mem
+    (BRIG_OPCODE_LD, dest->type, dest, addr);
+
+  hbb->append_insn (basic);
+}
+
+
+/* Emit instructions that set hsa_num_threads according to provided VALUE.
+ Intructions are appended to basic block HBB and SSA_MAP maps gimple
+ SSA names to HSA pseudo registers.  */
+
+static void
+gen_set_num_threads (tree value, hsa_bb *hbb, vec <hsa_op_reg_p> *ssa_map)
+{
+  hbb->append_insn (new hsa_insn_comment ("omp_set_num_threads"));
+  hsa_op_with_type *src = hsa_reg_or_immed_for_gimple_op (value, hbb,
+							  ssa_map);
+
+  BrigType16_t dtype = hsa_num_threads->type;
+  if (hsa_needs_cvt (dtype, src->type))
+    {
+      hsa_op_reg *tmp = new hsa_op_reg (dtype);
+      hbb->append_insn (new hsa_insn_basic (2, BRIG_OPCODE_CVT, tmp->type,
+					    tmp, src));
+      src = tmp;
+    }
+  else
+    src->type = dtype;
+
+  hsa_op_address *addr = new hsa_op_address (hsa_num_threads);
+
+  hsa_op_immed *limit = new hsa_op_immed (64, BRIG_TYPE_U32);
+  hsa_op_reg *r = new hsa_op_reg (BRIG_TYPE_B1);
+  hbb->append_insn
+    (new hsa_insn_cmp (BRIG_COMPARE_LT, r->type, r, src, limit));
+
+  BrigType16_t btype = hsa_bittype_for_type (hsa_num_threads->type);
+  hsa_op_reg *src_min_reg = new hsa_op_reg (btype);
+
+  hbb->append_insn
+    (new hsa_insn_basic (4, BRIG_OPCODE_CMOV, src_min_reg->type,
+			 src_min_reg, r, src, limit));
+
+  hsa_insn_basic *basic = new hsa_insn_mem
+    (BRIG_OPCODE_ST, hsa_num_threads->type, src_min_reg, addr);
+
+  hbb->append_insn (basic);
+}
+
+/* Emit instructions that assign number of teams to lhs of gimple STMT.
+   Intructions are appended to basic block HBB and SSA_MAP maps gimple
+   SSA names to HSA pseudo registers.  */
+
+static void
+gen_get_num_teams (gimple *stmt, hsa_bb *hbb, vec <hsa_op_reg_p> *ssa_map)
+{
+  if (gimple_call_lhs (stmt) == NULL_TREE)
+    return;
+
+  hbb->append_insn
+    (new hsa_insn_comment ("__builtin_omp_get_num_teams"));
+
+  tree lhs = gimple_call_lhs (stmt);
+  hsa_op_reg *dest = hsa_reg_for_gimple_ssa (lhs, ssa_map);
+  hsa_op_immed *one = new hsa_op_immed (1, dest->type);
+
+  hsa_insn_basic *basic = new hsa_insn_basic
+    (2, BRIG_OPCODE_MOV, dest->type, dest, one);
+
+  hbb->append_insn (basic);
+}
+
+/* Emit instructions that assign a team number to lhs of gimple STMT.
+   Intructions are appended to basic block HBB and SSA_MAP maps gimple
+   SSA names to HSA pseudo registers.  */
+
+static void
+gen_get_team_num (gimple *stmt, hsa_bb *hbb, vec <hsa_op_reg_p> *ssa_map)
+{
+  if (gimple_call_lhs (stmt) == NULL_TREE)
+    return;
+
+  hbb->append_insn
+    (new hsa_insn_comment ("__builtin_omp_get_team_num"));
+
+  tree lhs = gimple_call_lhs (stmt);
+  hsa_op_reg *dest = hsa_reg_for_gimple_ssa (lhs, ssa_map);
+  hsa_op_immed *zero = new hsa_op_immed (0, dest->type);
+
+  hsa_insn_basic *basic = new hsa_insn_basic
+    (2, BRIG_OPCODE_MOV, dest->type, dest, zero);
+
+  hbb->append_insn (basic);
+}
+
 /* If STMT is a call of a known library function, generate code to perform
    it and return true.  */
 
@@ -3165,6 +3278,27 @@  gen_hsa_insns_for_known_library_call (gimple *stmt, hsa_bb *hbb,
       hsa_build_append_simple_mov (dest, imm, hbb);
       return true;
     }
+  else if (strcmp (name, "omp_set_num_threads") == 0)
+    {
+      gen_set_num_threads (gimple_call_arg (stmt, 0), hbb, ssa_map);
+      return true;
+    }
+  else if (strcmp (name, "omp_get_num_threads") == 0)
+    {
+      gen_get_num_threads (stmt, hbb, ssa_map);
+      return true;
+    }
+  else if (strcmp (name, "omp_get_num_teams") == 0)
+    {
+      gen_get_num_teams (stmt, hbb, ssa_map);
+      return true;
+    }
+  else if (strcmp (name, "omp_get_team_num") == 0)
+    {
+      gen_get_team_num (stmt, hbb, ssa_map);
+      return true;
+    }
+
   return false;
 }
 
@@ -3370,21 +3504,33 @@  gen_hsa_insns_for_kernel_call (hsa_bb *hbb, gcall *call)
   hbb->append_insn (mem);
 
   /* Write to packet->grid_size_x.  */
-  hbb->append_insn (new hsa_insn_comment ("set packet->grid_size_x = 64"));
+  hbb->append_insn (new hsa_insn_comment
+		    ("set packet->grid_size_x = hsa_num_threads"));
 
   addr = new hsa_op_address (queue_packet_reg,
 			     offsetof (hsa_queue_packet, grid_size_x));
-  c = new hsa_op_immed (64, BRIG_TYPE_U16);
-  mem = new hsa_insn_mem (BRIG_OPCODE_ST, BRIG_TYPE_U16, c, addr);
+
+  hsa_op_reg *hsa_num_threads_reg = new hsa_op_reg (hsa_num_threads->type);
+  hbb->append_insn (new hsa_insn_mem (BRIG_OPCODE_LD, hsa_num_threads->type,
+				      hsa_num_threads_reg,
+				      new hsa_op_address (hsa_num_threads)));
+
+  hsa_op_reg *threads_u16_reg = new hsa_op_reg (BRIG_TYPE_U16);
+  hbb->append_insn (new hsa_insn_basic (2, BRIG_OPCODE_CVT, BRIG_TYPE_U16,
+					threads_u16_reg, hsa_num_threads_reg));
+
+  mem = new hsa_insn_mem (BRIG_OPCODE_ST, BRIG_TYPE_U16, threads_u16_reg,
+			  addr);
   hbb->append_insn (mem);
 
   /* Write to packet->workgroup_size_x.  */
-  hbb->append_insn (new hsa_insn_comment ("set packet->workgroup_size_x = 64"));
+  hbb->append_insn (new hsa_insn_comment
+		    ("set packet->workgroup_size_x = hsa_num_threads"));
 
   addr = new hsa_op_address (queue_packet_reg,
 			     offsetof (hsa_queue_packet, workgroup_size_x));
-  c = new hsa_op_immed (64, BRIG_TYPE_U16);
-  mem = new hsa_insn_mem (BRIG_OPCODE_ST, BRIG_TYPE_U16, c, addr);
+  mem = new hsa_insn_mem (BRIG_OPCODE_ST, BRIG_TYPE_U16, threads_u16_reg,
+			  addr);
   hbb->append_insn (mem);
 
   /* Write to packet->grid_size_y.  */
@@ -3791,8 +3937,10 @@  gen_hsa_insns_for_call (gimple *stmt, hsa_bb *hbb,
       goto specialop;
 
     case BUILT_IN_OMP_GET_NUM_THREADS:
-      opcode = BRIG_OPCODE_GRIDSIZE;
-      goto specialop;
+      {
+	gen_get_num_threads (stmt, hbb, ssa_map);
+	break;
+      }
 
 specialop:
       {
@@ -4072,6 +4220,21 @@  specialop:
 
 	break;
       }
+    case BUILT_IN_GOMP_TEAMS:
+      {
+	gen_set_num_threads (gimple_call_arg (stmt, 1), hbb, ssa_map);
+	break;
+      }
+    case BUILT_IN_OMP_GET_NUM_TEAMS:
+      {
+	gen_get_num_teams (stmt, hbb, ssa_map);
+	break;
+      }
+    case BUILT_IN_OMP_GET_TEAM_NUM:
+      {
+	gen_get_team_num (stmt, hbb, ssa_map);
+	break;
+      }
     case BUILT_IN_MEMCPY:
       {
 	tree byte_size = gimple_call_arg (stmt, 2);
@@ -4341,6 +4504,17 @@  hsa_init_new_bb (basic_block bb)
   return new (hsa_allocp_bb) hsa_bb (bb);
 }
 
+/* Initialize OMP in an HSA basic block PROLOGUE.  */
+
+static void
+init_omp_in_prologue (hsa_bb *prologue)
+{
+  BrigType16_t t = hsa_num_threads->type;
+  prologue->append_insn
+    (new hsa_insn_mem (BRIG_OPCODE_ST, t, new hsa_op_immed (64, t),
+		       new hsa_op_address (hsa_num_threads)));
+}
+
 /* Go over gimple representation and generate our internal HSA one.  SSA_MAP
    maps gimple SSA names to HSA pseudo registers.  */
 
@@ -4380,6 +4554,8 @@  gen_body_from_gimple (vec <hsa_op_reg_p> *ssa_map)
 	}
     }
 
+  init_omp_in_prologue (hsa_bb_for_bb (ENTRY_BLOCK_PTR_FOR_FN (cfun)));
+
   FOR_EACH_BB_FN (bb, cfun)
     {
       gimple_stmt_iterator gsi;
@@ -4790,6 +4966,23 @@  convert_switch_statements ()
     }
 }
 
+/* Emit HSA module variables that are global for the entire module.  */
+
+static void
+emit_hsa_module_variables (void)
+{
+  hsa_num_threads = new hsa_symbol ();
+  memset (hsa_num_threads, 0, sizeof (hsa_symbol));
+
+  hsa_num_threads->name = "hsa_num_threads";
+  hsa_num_threads->type = BRIG_TYPE_U32;
+  hsa_num_threads->segment = BRIG_SEGMENT_PRIVATE;
+  hsa_num_threads->linkage = BRIG_LINKAGE_MODULE;
+  hsa_num_threads->global_scope_p = true;
+
+  hsa_brig_emit_omp_symbols ();
+}
+
 /* Generate HSAIL representation of the current function and write into a
    special section of the output file.  If KERNEL is set, the function will be
    considered an HSA kernel callable from the host, otherwise it will be
@@ -4798,6 +4991,9 @@  convert_switch_statements ()
 static void
 generate_hsa (bool kernel)
 {
+  if (hsa_num_threads == NULL)
+    emit_hsa_module_variables ();
+
   verify_function_arguments (cfun->decl);
   if (seen_error ())
     return;
diff --git a/gcc/hsa.c b/gcc/hsa.c
index 3cb5a5a..ce8ae45 100644
--- a/gcc/hsa.c
+++ b/gcc/hsa.c
@@ -104,6 +104,9 @@  hash_table <hsa_free_symbol_hasher> *hsa_global_variable_symbols;
 /* HSA summaries.  */
 hsa_summary_t *hsa_summaries = NULL;
 
+/* HSA number of threads.  */
+hsa_symbol *hsa_num_threads = NULL;
+
 /* True if compilation unit-wide data are already allocated and initialized.  */
 static bool compilation_unit_data_initialized;
 
diff --git a/gcc/hsa.h b/gcc/hsa.h
index 3f0d122..1382ac1 100644
--- a/gcc/hsa.h
+++ b/gcc/hsa.h
@@ -43,6 +43,9 @@  hsa_gen_requested_p (void)
 class hsa_op_immed;
 class hsa_op_cst_list;
 class hsa_insn_basic;
+class hsa_op_address;
+class hsa_op_reg;
+class hsa_bb;
 typedef hsa_insn_basic *hsa_insn_basic_p;
 
 /* Class representing an input argument, output argument (result) or a
@@ -80,6 +83,9 @@  struct hsa_symbol
 
   /* Constant value, used for string constants.  */
   hsa_op_immed *cst_value;
+
+  /* Is in global scope.  */
+  bool global_scope_p;
 };
 
 /* Abstract class for HSA instruction operands. */
@@ -446,8 +452,6 @@  is_a_helper <hsa_insn_br *>::test (hsa_insn_basic *p)
     || p->opcode == BRIG_OPCODE_CBR;
 }
 
-class hsa_bb;
-
 /* HSA instruction for swtich branche.  */
 
 class hsa_insn_sbr : public hsa_insn_basic
@@ -494,7 +498,9 @@  is_a_helper <hsa_insn_sbr *>::test (hsa_insn_basic *p)
 class hsa_insn_cmp : public hsa_insn_basic
 {
 public:
-  hsa_insn_cmp (BrigCompareOperation8_t cmp, BrigType16_t t);
+  hsa_insn_cmp (BrigCompareOperation8_t cmp, BrigType16_t t,
+		hsa_op_base *arg0 = NULL, hsa_op_base *arg1 = NULL,
+		hsa_op_base *arg2 = NULL);
 
   void *operator new (size_t);
 
@@ -1025,6 +1031,7 @@  extern struct hsa_function_representation *hsa_cfun;
 extern hash_table <hsa_free_symbol_hasher> *hsa_global_variable_symbols;
 extern hash_map <tree, vec <char *> *> *hsa_decl_kernel_dependencies;
 extern hsa_summary_t *hsa_summaries;
+extern hsa_symbol *hsa_num_threads;
 extern unsigned hsa_kernel_calls_counter;
 bool hsa_callable_function_p (tree fndecl);
 void hsa_init_compilation_unit_data (void);
@@ -1069,6 +1076,7 @@  void hsa_brig_emit_function (void);
 void hsa_output_brig (void);
 BrigType16_t bittype_for_type (BrigType16_t t);
 unsigned hsa_get_imm_brig_type_len (BrigType16_t type);
+void hsa_brig_emit_omp_symbols (void);
 
 /*  In hsa-dump.c.  */
 const char *hsa_seg_name (BrigSegment8_t);
-- 
2.5.1