OpenMP/Fortran: 'target update' with strides + DT components

Message ID 51e764f7-635f-9754-dc4b-d2cd2b58435d@codesourcery.com
State New

Commit Message

Tobias Burnus Oct. 31, 2022, 2:46 p.m. UTC
I recently saw that gfortran does not support derived type components
with 'target update', an OpenMP 5.0 feature.

When adding it, I also found out that strides were not handled. There
is probably some room for improvement regarding what to copy and what
not, but copying too much should be fine.

Built + (reg)tested on x86_64-gnu-linux without offloading configured
+ libgomp tested on x86_64-gnu-linux with nvptx offloading.
OK for mainline?

  * * *

PS: Follow-up work items:
* Strides: OpenMP seemingly also permits 'a%b([1,6,19,12])' as
   long as the first index has the lowest address. 'a%b(:)%c'
   is permitted as well. Neither is handled in this patch
   (both are rejected with a compile-time error).
* There seem to be some problems with 'alloc' with pointers
   and allocatables in components, but I have not rechecked.
* For allocatables, 'target update' needs to do a deep mapping;
   I need to check whether that's the case.
Note for the last two: allocatable components only work on OG11/OG12,
and I urgently need to clean up + (re)submit that patch to mainline.
(It came too late for GCC 12.)

* There might also be some issue with mapping/refcounting, which I have
   not investigated; it affects the 'target exit data' of target-11.f90.

PPS: I intend to file one or more PRs about those issues, unless
I can fix them quickly.

Comments

Jakub Jelinek Nov. 3, 2022, 12:44 p.m. UTC | #1
On Mon, Oct 31, 2022 at 03:46:25PM +0100, Tobias Burnus wrote:
> OpenMP/Fortran: 'target update' with strides + DT components
> 
> OpenMP 5.0 permits to use arrays with strides and derived
> type components for the list items to the 'from'/'to' clauses
> of the 'target update' directive.
> 
> gcc/fortran/ChangeLog:
> 
> 	* openmp.cc (gfc_match_omp_clauses): Permit derived types.
> 	(resolve_omp_clauses):Accept noncontiguous
> 	arrays.

Formatting: a space is missing before "Accept", and "arrays." could fit
on the same line.

> 	* trans-openmp.cc (gfc_trans_omp_clauses): Fixes for
> 	derived-type changes; fix size for scalars.
> 
> libgomp/ChangeLog:
> 
> 	* testsuite/libgomp.fortran/target-11.f90: New test.
> 	* testsuite/libgomp.fortran/target-13.f90: New test.

Otherwise LGTM, assuming it actually works correctly.

I don't remember support for non-contiguous copying to/from devices
being actually added, on the library side we certainly have
omp_target_memcpy_rect which under the hood just does multiple copies
of the contiguous subparts, but I don't remember something similar
done in GOMP_target_update.  And I think it is not ok to copy bytes
that aren't requested to be copied.

	Jakub
Patch

OpenMP/Fortran: 'target update' with strides + DT components

OpenMP 5.0 permits to use arrays with strides and derived
type components for the list items to the 'from'/'to' clauses
of the 'target update' directive.

gcc/fortran/ChangeLog:

	* openmp.cc (gfc_match_omp_clauses): Permit derived types.
	(resolve_omp_clauses):Accept noncontiguous
	arrays.
	* trans-openmp.cc (gfc_trans_omp_clauses): Fixes for
	derived-type changes; fix size for scalars.

libgomp/ChangeLog:

	* testsuite/libgomp.fortran/target-11.f90: New test.
	* testsuite/libgomp.fortran/target-13.f90: New test.

 gcc/fortran/openmp.cc                           |  19 ++-
 gcc/fortran/trans-openmp.cc                     |   9 +-
 libgomp/testsuite/libgomp.fortran/target-11.f90 |  75 +++++++++++
 libgomp/testsuite/libgomp.fortran/target-13.f90 | 162 ++++++++++++++++++++++++
 4 files changed, 256 insertions(+), 9 deletions(-)

diff --git a/gcc/fortran/openmp.cc b/gcc/fortran/openmp.cc
index 653c43f79ff..2daed74be72 100644
--- a/gcc/fortran/openmp.cc
+++ b/gcc/fortran/openmp.cc
@@ -2499,9 +2499,10 @@  gfc_match_omp_clauses (gfc_omp_clauses **cp, const omp_mask mask,
 					      true) == MATCH_YES)
 	    continue;
 	  if ((mask & OMP_CLAUSE_FROM)
-	      && gfc_match_omp_variable_list ("from (",
+	      && (gfc_match_omp_variable_list ("from (",
 					      &c->lists[OMP_LIST_FROM], false,
-					      NULL, &head, true) == MATCH_YES)
+					      NULL, &head, true, true)
+		  == MATCH_YES))
 	    continue;
 	  break;
 	case 'g':
@@ -3436,9 +3437,10 @@  gfc_match_omp_clauses (gfc_omp_clauses **cp, const omp_mask mask,
 		continue;
 	    }
 	  else if ((mask & OMP_CLAUSE_TO)
-	      && gfc_match_omp_variable_list ("to (",
+	      && (gfc_match_omp_variable_list ("to (",
 					      &c->lists[OMP_LIST_TO], false,
-					      NULL, &head, true) == MATCH_YES)
+					      NULL, &head, true, true)
+		  == MATCH_YES))
 	    continue;
 	  break;
 	case 'u':
@@ -7585,8 +7587,11 @@  resolve_omp_clauses (gfc_code *code, gfc_omp_clauses *omp_clauses,
 			   Only raise an error here if we're really sure the
 			   array isn't contiguous.  An expression such as
 			   arr(-n:n,-n:n) could be contiguous even if it looks
-			   like it may not be.  */
+			   like it may not be.
+			   And OpenMP's 'target update' permits strides for
+			   the to/from clause. */
 			if (code->op != EXEC_OACC_UPDATE
+			    && code->op != EXEC_OMP_TARGET_UPDATE
 			    && list != OMP_LIST_CACHE
 			    && list != OMP_LIST_DEPEND
 			    && !gfc_is_simply_contiguous (n->expr, false, true)
@@ -7630,7 +7635,9 @@  resolve_omp_clauses (gfc_code *code, gfc_omp_clauses *omp_clauses,
 			int i;
 			gfc_array_ref *ar = &lastslice->u.ar;
 			for (i = 0; i < ar->dimen; i++)
-			  if (ar->stride[i] && code->op != EXEC_OACC_UPDATE)
+			  if (ar->stride[i]
+			      && code->op != EXEC_OACC_UPDATE
+			      && code->op != EXEC_OMP_TARGET_UPDATE)
 			    {
 			      gfc_error ("Stride should not be specified for "
 					 "array section in %s clause at %L",
diff --git a/gcc/fortran/trans-openmp.cc b/gcc/fortran/trans-openmp.cc
index 9bd4e6c7e1b..4bfdf85cd9b 100644
--- a/gcc/fortran/trans-openmp.cc
+++ b/gcc/fortran/trans-openmp.cc
@@ -3626,7 +3626,10 @@  gfc_trans_omp_clauses (stmtblock_t *block, gfc_omp_clauses *clauses,
 		  gcc_unreachable ();
 		}
 	      tree node = build_omp_clause (input_location, clause_code);
-	      if (n->expr == NULL || n->expr->ref->u.ar.type == AR_FULL)
+	      if (n->expr == NULL
+		  || (n->expr->ref->type == REF_ARRAY
+		      && n->expr->ref->u.ar.type == AR_FULL
+		      && n->expr->ref->next == NULL))
 		{
 		  tree decl = gfc_trans_omp_variable (n->sym, false);
 		  if (gfc_omp_privatize_by_reference (decl))
@@ -3666,13 +3669,13 @@  gfc_trans_omp_clauses (stmtblock_t *block, gfc_omp_clauses *clauses,
 		{
 		  tree ptr;
 		  gfc_init_se (&se, NULL);
-		  if (n->expr->ref->u.ar.type == AR_ELEMENT)
+		  if (n->expr->rank == 0)
 		    {
 		      gfc_conv_expr_reference (&se, n->expr);
 		      ptr = se.expr;
 		      gfc_add_block_to_block (block, &se.pre);
 		      OMP_CLAUSE_SIZE (node)
-			= TYPE_SIZE_UNIT (TREE_TYPE (ptr));
+			= TYPE_SIZE_UNIT (TREE_TYPE (TREE_TYPE (ptr)));
 		    }
 		  else
 		    {
diff --git a/libgomp/testsuite/libgomp.fortran/target-11.f90 b/libgomp/testsuite/libgomp.fortran/target-11.f90
new file mode 100644
index 00000000000..b0faa2e620d
--- /dev/null
+++ b/libgomp/testsuite/libgomp.fortran/target-11.f90
@@ -0,0 +1,75 @@ 
+! Based on libgomp.c/target-23.c
+
+! { dg-additional-options "-fdump-tree-original" }
+! { dg-final { scan-tree-dump "omp target update to\\(xxs\\\[3\\\] \\\[len: 2\\\]\\)" "original" } }
+! { dg-final { scan-tree-dump "omp target update to\\(s\\.s \\\[len: 4\\\]\\)" "original" } }
+! { dg-final { scan-tree-dump "omp target update from\\(s\\.s \\\[len: 4\\\]\\)" "original" } }
+
+module m
+  implicit none
+  type S_type
+    integer s
+    integer, pointer :: u(:) => null()
+    integer :: v(0:4)
+  end type S_type
+  integer, volatile :: z
+end module m
+
+program main
+  use m
+  implicit none
+  integer, target :: u(0:9) = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+  logical :: err
+  type (S_type) :: s
+  integer, pointer :: v(:)
+  integer(kind=2) :: xxs(5)
+  err = .false.
+  s = S_type(9, v=[10, 11, 12, 13, 14])
+  s%u(0:) => u(3:)
+  v(-4+3:) => u(3:)
+  xxs = [-1,-2,-3,-4,-5]
+  !$omp target enter data map (to: s%s, s%u, s%u(0:5)) map (alloc: s%v(1:4), xxs(3:5))
+  s%s = s%s + 1
+  u(3) = u(3) + 1
+  s%v(1) = s%v(1) + 1
+  xxs(3) = -33
+  xxs(4) = -44
+  xxs(5) = -55
+  !$omp target update to (xxs(4))
+  !$omp target update to (s%s) to (s%u(0:2), s%v(1:4))
+
+  !$omp target map (alloc: s%s, s%v(1:4)) map (from: err)
+    err = .false.
+    if (s%s /= 10 .or. s%v(1) /= 12 .or. s%v(2) /= 12 .or. s%v(3) /= 13) &
+      err = .true.
+    if (v(-1) /= 4 .or. v(0) /= 4 .or. v(1) /= 5 .or. v(2) /= 6 .or. v(3) /= 7) &
+      err = .true.
+    if (xxs(4) /= -44) &
+      err = .true.
+    s%s = s%s + 1
+    s%v(2) = s%v(2) + 2
+    v(-1) = 5
+    v(3) = 9
+  !$omp end target
+
+  if (err) &
+    error stop
+
+  !$omp target map (alloc: s%u(0:5))
+    err = .false.
+    if (s%u(0) /= 5 .or. s%u(1) /= 4 .or. s%u(2) /= 5 .or. s%u(3) /= 6 .or. s%u(4) /= 9) &
+      err = .true.
+    s%u(1) = 12
+  !$omp end target
+
+  !$omp target update from (s%s, s%u(0:5)) from (s%v(1:4))
+  if (err .or. s%s /= 11 .or. u(0) /= 0 .or. u(1) /= 1 .or. u(2) /= 2 .or. u(3) /= 5 &
+      .or. u(4) /= 12 .or. u(5) /= 5 .or. u(6) /= 6 .or. u(7) /= 9 .or. u(8) /= 8    &
+      .or. u(9) /= 9 .or. s%v(0) /= 10 .or. s%v(1) /= 12 .or. s%v(2) /= 14           &
+      .or. s%v(3) /= 13 .or. s%v(4) /= 14)                                           &
+    error stop
+  ! !$omp target exit data map (release: s%s)
+  ! !$omp target exit data map (release: s%u(0:5))
+  ! !$omp target exit data map (delete: s%v(1:4))
+  ! !$omp target exit data map (release: s%s)
+end
diff --git a/libgomp/testsuite/libgomp.fortran/target-13.f90 b/libgomp/testsuite/libgomp.fortran/target-13.f90
new file mode 100644
index 00000000000..e6334a5275f
--- /dev/null
+++ b/libgomp/testsuite/libgomp.fortran/target-13.f90
@@ -0,0 +1,162 @@ 
+module m
+  implicit none
+  type t
+    integer :: s, a(5)
+  end type t
+
+  type t2
+    integer :: s, a(5)
+    type(t) :: st, at(2:3)
+  end type t2
+
+  interface operator(/=)
+    procedure ne_compare_t
+    procedure ne_compare_t2
+  end interface
+
+contains
+
+  logical pure elemental function ne_compare_t (a, b) result(res)
+    type(t), intent(in) :: a, b
+    res = (a%s /= b%s) .or. any(a%a /= b%a)
+  end function
+
+  logical pure elemental function ne_compare_t2 (a, b) result(res)
+    type(t2), intent(in) :: a, b
+    res = (a%s /= b%s) .or. any(a%a /= b%a)     &
+          .or. (a%st /= b%st) .or. any(a%at /= b%at)
+  end function
+end module m
+
+program p
+use m
+implicit none
+
+type(t2) :: var1, var2(5), var3(:)
+type(t2) :: var1a, var2a(5), var3a(:)
+allocatable :: var3, var3a
+logical :: shared_memory = .false.
+
+!$omp target map(to: shared_memory)
+  shared_memory = .true.
+!$omp end target
+
+var1 = T2(1, [1,2,3,4,5], T(11, [11,22,33,44,55]), &
+          [T(-11, [-11,-22,-33,-44,-55]), T(11, [11,22,33,44,55])])
+
+var2 = [T2(101, [201,202,203,204,205], T(2011, [2011,2022,2033,2044,2055]), &
+           [T(-11, [-11,-22,-33,-44,-55]), T(11, [11,22,33,44,55])]),       &
+        T2(111, [211,212,213,214,215], T(2111, [2111,2122,2133,2144,2155]), &
+           [T(-11, [-11,-22,-33,-44,-55]), T(11, [11,22,33,44,55])]),       &
+        T2(121, [221,222,223,224,225], T(2211, [2211,2222,2233,2244,2255]), &
+           [T(-11, [-11,-22,-33,-44,-55]), T(11, [11,22,33,44,55])]),       &
+        T2(131, [231,232,233,234,235], T(2311, [2311,2322,2333,2344,2355]), &
+           [T(-11, [-11,-22,-33,-44,-55]), T(11, [11,22,33,44,55])]),       &
+        T2(141, [241,242,243,244,245], T(2411, [2411,2422,2433,2444,2455]), &
+           [T(-11, [-11,-22,-33,-44,-55]), T(11, [11,22,33,44,55])])]
+
+var3 = [T2(301, [401,402,403,404,405], T(4011, [4011,4022,4033,4044,4055]), &
+           [T(-11, [-11,-22,-33,-44,-55]), T(11, [11,22,33,44,55])]),       &
+        T2(311, [411,412,413,414,415], T(4111, [4111,4122,4133,4144,4155]), &
+           [T(-11, [-11,-22,-33,-44,-55]), T(11, [11,22,33,44,55])]),       &
+        T2(321, [421,422,423,424,425], T(4211, [4211,4222,4233,4244,4255]), &
+           [T(-11, [-11,-22,-33,-44,-55]), T(11, [11,22,33,44,55])]),       &
+        T2(331, [431,432,433,434,435], T(4311, [4311,4322,4333,4344,4355]), &
+           [T(-11, [-11,-22,-33,-44,-55]), T(11, [11,22,33,44,55])]),       &
+        T2(341, [441,442,443,444,445], T(4411, [4411,4422,4433,4444,4455]), &
+           [T(-11, [-11,-22,-33,-44,-55]), T(11, [11,22,33,44,55])])]
+
+var1a = var1
+var2a = var2
+var3a = var3
+
+!$omp target enter data map(to:var1)
+!$omp target enter data map(to:var2)
+!$omp target enter data map(to:var3)
+
+! ---------------
+
+!$omp target update from(var1%at(::2))
+
+if (var1a /= var1) error stop
+if (any (var2a /= var2)) error stop
+if (any (var3a /= var3)) error stop
+
+! ---------------
+
+!$omp target
+  var1%st%s = 1243
+  var2(3)%at(2) = T(123, [345,64,356,39,13])
+  var2(3)%at(3) = T(48, [74,162,572,357,3])
+!$omp end target
+
+if (.not. shared_memory) then
+  if (var1 /= var1) error stop
+  if (any (var2a /= var2)) error stop
+  if (any (var3a /= var3)) error stop
+endif
+
+!$omp target update from(var1%st) from(var2(3)%at(2:3))
+
+var1a%st%s = 1243
+var2a(3)%at(2) = T(123, [345,64,356,39,13])
+var2a(3)%at(3) = T(48, [74,162,572,357,3])
+if (var1 /= var1) error stop
+if (any (var2a /= var2)) error stop
+if (any (var3a /= var3)) error stop
+
+! ---------------
+
+var3(1) = var2(1)
+var1%at(2)%a = var2(1)%a
+var1%at(3)%a = var2(2)%a
+
+var1a = var1
+var2a = var2
+var3a = var3
+
+!$omp target update to(var3) to(var1%at(2:3))
+
+!$omp target
+  var3(1)%s = var3(1)%s + 123
+  var1%at(2)%a = var1%at(2)%a * 7
+  var1%at(3)%s = var1%at(3)%s * (-3)
+!$omp end target
+
+if (.not. shared_memory) then
+  if (var1 /= var1) error stop
+  if (any (var2a /= var2)) error stop
+  if (any (var3a /= var3)) error stop
+endif
+
+var3a(1)%s = var3a(1)%s + 123
+var1a%at(2)%a = var1a%at(2)%a * 7
+var1a%at(3)%s = var1a%at(3)%s * (-3)
+
+block
+  integer, volatile :: i1,i2,i3,i4,i5,i6
+  i1 = 1
+  i2 = 2
+  i3 = 1
+  i4 = 1
+  i5 = 2
+  i6 = 1
+  !$omp target update from(var3(i1:i2:i3)) from(var1%at(i4:i5:i6))
+  i1 = 3
+  i2 = 3
+  i3 = 1
+  i4 = 5
+  i5 = 1
+  !$omp target update from(var1%at(i1)%s) from(var1%at(i1)%a(i3:i4:i5))
+end block
+
+if (var1 /= var1) error stop
+if (any (var2a /= var2)) error stop
+if (any (var3a /= var3)) error stop
+
+! ---------------
+
+!$omp target exit data map(from:var1)
+!$omp target exit data map(from:var2)
+!$omp target exit data map(from:var3)
+end