diff mbox

Fix Fortran deviceptr clause.

Message ID 567169ED.4040204@codesourcery.com
State New
Headers show

Commit Message

James Norris Dec. 16, 2015, 1:41 p.m. UTC
Hi,

This is an update of my previous patch. Cesar (thanks!)
pointed out some issues with the original patch that
have now been addressed.

Regtested on x86_64

OK for trunk?

Thanks!
Jim
diff mbox

Patch

diff --git a/gcc/fortran/openmp.c b/gcc/fortran/openmp.c
index 276f2f1..9350dc4 100644
--- a/gcc/fortran/openmp.c
+++ b/gcc/fortran/openmp.c
@@ -812,19 +812,10 @@  gfc_match_omp_clauses (gfc_omp_clauses **cp, uint64_t mask,
 				       OMP_MAP_ALLOC))
 	continue;
       if ((mask & OMP_CLAUSE_DEVICEPTR)
-	  && gfc_match ("deviceptr ( ") == MATCH_YES)
-	{
-	  gfc_omp_namelist **list = &c->lists[OMP_LIST_MAP];
-	  gfc_omp_namelist **head = NULL;
-	  if (gfc_match_omp_variable_list ("", list, true, NULL, &head, false)
-	      == MATCH_YES)
-	    {
-	      gfc_omp_namelist *n;
-	      for (n = *head; n; n = n->next)
-		n->u.map_op = OMP_MAP_FORCE_DEVICEPTR;
-	      continue;
-	    }
-	}
+	  && gfc_match ("deviceptr ( ") == MATCH_YES
+	  && gfc_match_omp_map_clause (&c->lists[OMP_LIST_MAP],
+				       OMP_MAP_FORCE_DEVICEPTR))
+	continue;
       if ((mask & OMP_CLAUSE_USE_DEVICE)
 	  && gfc_match_omp_variable_list ("use_device (",
 					  &c->lists[OMP_LIST_USE_DEVICE], true)
diff --git a/libgomp/oacc-parallel.c b/libgomp/oacc-parallel.c
index db7cab3..98982c3 100644
--- a/libgomp/oacc-parallel.c
+++ b/libgomp/oacc-parallel.c
@@ -49,6 +49,51 @@  find_pset (int pos, size_t mapnum, unsigned short *kinds)
   return kind == GOMP_MAP_TO_PSET;
 }
 
+/* Handle the mapping pair that are presented when a
+   deviceptr clause is used with Fortran.  */
+
+static void
+handle_ftn_pointers (size_t mapnum, void **hostaddrs, size_t *sizes,
+		     unsigned short *kinds)
+{
+  int i;
+
+  for (i = 0; i < mapnum; i++)
+    {
+      unsigned short kind1 = kinds[i] & 0xff;
+
+      /* Handle Fortran deviceptr clause.  */
+      if (kind1 == GOMP_MAP_FORCE_DEVICEPTR)
+	{
+	  unsigned short kind2;
+
+	  if (i < (signed)mapnum - 1)
+	    kind2 = kinds[i + 1] & 0xff;
+	  else
+	    kind2 = 0xffff;
+
+	  if (sizes[i] == sizeof (void *))
+	    continue;
+
+	  /* At this point, we're dealing with a Fortran deviceptr.
+	     If the next element is not what we're expecting, then
+	     this is an instance of where the deviceptr variable was
+	     not used within the region and the pointer was removed
+	     by the gimplifier.  */
+	  if (kind2 == GOMP_MAP_POINTER
+	      && sizes[i + 1] == 0
+	      && hostaddrs[i] == *(void **)hostaddrs[i + 1])
+	    {
+	      kinds[i+1] = kinds[i];
+	      sizes[i+1] = sizeof (void *);
+	    }
+
+	  /* Invalidate the entry.  */
+	  hostaddrs[i] = NULL;
+	}
+    }
+}
+
 static void goacc_wait (int async, int num_waits, va_list *ap);
 
 
@@ -88,6 +133,8 @@  GOACC_parallel_keyed (int device, void (*fn) (void *),
   thr = goacc_thread ();
   acc_dev = thr->dev;
 
+  handle_ftn_pointers (mapnum, hostaddrs, sizes, kinds);
+
   /* Host fallback if "if" clause is false or if the current device is set to
      the host.  */
   if (host_fallback)
@@ -172,8 +219,13 @@  GOACC_parallel_keyed (int device, void (*fn) (void *),
 
   devaddrs = gomp_alloca (sizeof (void *) * mapnum);
   for (i = 0; i < mapnum; i++)
-    devaddrs[i] = (void *) (tgt->list[i].key->tgt->tgt_start
-			    + tgt->list[i].key->tgt_offset);
+    {
+      if (tgt->list[i].key != NULL)
+	devaddrs[i] = (void *) (tgt->list[i].key->tgt->tgt_start
+				+ tgt->list[i].key->tgt_offset);
+      else
+	devaddrs[i] = NULL;
+    }
 
   acc_dev->openacc.exec_func (tgt_fn, mapnum, hostaddrs, devaddrs,
 			      async, dims, tgt);
@@ -224,6 +276,8 @@  GOACC_data_start (int device, size_t mapnum,
   struct goacc_thread *thr = goacc_thread ();
   struct gomp_device_descr *acc_dev = thr->dev;
 
+  handle_ftn_pointers (mapnum, hostaddrs, sizes, kinds);
+
   /* Host fallback or 'do nothing'.  */
   if ((acc_dev->capabilities & GOMP_OFFLOAD_CAP_SHARED_MEM)
       || host_fallback)
diff --git a/libgomp/testsuite/libgomp.oacc-fortran/declare-1.f90 b/libgomp/testsuite/libgomp.oacc-fortran/declare-1.f90
index f717d1b..2d4b707 100644
--- a/libgomp/testsuite/libgomp.oacc-fortran/declare-1.f90
+++ b/libgomp/testsuite/libgomp.oacc-fortran/declare-1.f90
@@ -1,29 +1,22 @@ 
 ! { dg-do run  { target openacc_nvidia_accel_selected } }
 
+! Tests to exercise the declare directive along with
+! the clauses: copy
+!              copyin
+!              copyout
+!              create
+!              present
+!              present_or_copy
+!              present_or_copyin
+!              present_or_copyout
+!              present_or_create
+
 module vars
   implicit none
   integer z
   !$acc declare create (z)
 end module vars
 
-subroutine subr6 (a, d)
-  implicit none
-  integer, parameter :: N = 8
-  integer :: i
-  integer :: a(N)
-  !$acc declare deviceptr (a)
-  integer :: d(N)
-
-  i = 0
-
-  !$acc parallel copy (d)
-    do i = 1, N
-      d(i) = a(i) + a(i)
-    end do
-  !$acc end parallel
-
-end subroutine
-
 subroutine subr5 (a, b, c, d)
   implicit none
   integer, parameter :: N = 8
@@ -201,15 +194,6 @@  subroutine subr0 (a, b, c, d)
     if (d(i) .ne. 13) call abort
   end do
 
-  call subr6 (a, d)
-
-  call test (a, .true.)
-  call test (d, .false.)
-
-  do i = 1, N
-    if (d(i) .ne. 16) call abort
-  end do
-
 end subroutine
 
 program main
@@ -241,8 +225,7 @@  program main
     if (a(i) .ne. 8) call abort
     if (b(i) .ne. 8) call abort
     if (c(i) .ne. 8) call abort
-    if (d(i) .ne. 16) call abort
+    if (d(i) .ne. 13) call abort
   end do
 
-
 end program
diff --git a/libgomp/testsuite/libgomp.oacc-fortran/deviceptr-1.f90 b/libgomp/testsuite/libgomp.oacc-fortran/deviceptr-1.f90
new file mode 100644
index 0000000..276a172
--- /dev/null
+++ b/libgomp/testsuite/libgomp.oacc-fortran/deviceptr-1.f90
@@ -0,0 +1,197 @@ 
+! { dg-do run }
+
+! Test the deviceptr clause with various directives
+! and in combination with other directives where
+! the deviceptr variable is implied.
+
+subroutine subr1 (a, b)
+  implicit none
+  integer, parameter :: N = 8
+  integer :: a(N)
+  integer :: b(N)
+  integer :: i = 0
+
+  !$acc data deviceptr (a)
+
+  !$acc parallel copy (b)
+    do i = 1, N
+      a(i) = i * 2
+      b(i) = a(i)
+    end do
+  !$acc end parallel
+
+  !$acc end data
+
+end subroutine
+
+subroutine subr2 (a, b)
+  implicit none
+  integer, parameter :: N = 8
+  integer :: a(N)
+  !$acc declare deviceptr (a)
+  integer :: b(N)
+  integer :: i = 0
+
+  !$acc parallel copy (b)
+    do i = 1, N
+      a(i) = i * 4
+      b(i) = a(i)
+    end do
+  !$acc end parallel
+
+end subroutine
+
+subroutine subr3 (a, b)
+  implicit none
+  integer, parameter :: N = 8
+  integer :: a(N)
+  !$acc declare deviceptr (a)
+  integer :: b(N)
+  integer :: i = 0
+
+  !$acc kernels copy (b)
+    do i = 1, N
+      a(i) = i * 8
+      b(i) = a(i)
+    end do
+  !$acc end kernels
+
+end subroutine
+
+subroutine subr4 (a, b)
+  implicit none
+  integer, parameter :: N = 8
+  integer :: a(N)
+  integer :: b(N)
+  integer :: i = 0
+
+  !$acc parallel deviceptr (a) copy (b)
+    do i = 1, N
+      a(i) = i * 16
+      b(i) = a(i)
+    end do
+  !$acc end parallel
+
+end subroutine
+
+subroutine subr5 (a, b)
+  implicit none
+  integer, parameter :: N = 8
+  integer :: a(N)
+  integer :: b(N)
+  integer :: i = 0
+
+  !$acc kernels deviceptr (a) copy (b)
+    do i = 1, N
+      a(i) = i * 32
+      b(i) = a(i)
+    end do
+  !$acc end kernels
+
+end subroutine
+
+subroutine subr6 (a, b)
+  implicit none
+  integer, parameter :: N = 8
+  integer :: a(N)
+  integer :: b(N)
+  integer :: i = 0
+
+  !$acc parallel deviceptr (a) copy (b)
+    do i = 1, N
+      b(i) = i
+    end do
+  !$acc end parallel
+
+end subroutine
+
+subroutine subr7 (a, b)
+  implicit none
+  integer, parameter :: N = 8
+  integer :: a(N)
+  integer :: b(N)
+  integer :: i = 0
+
+  !$acc data deviceptr (a)
+
+  !$acc parallel copy (b)
+    do i = 1, N
+      a(i) = i * 2
+      b(i) = a(i)
+    end do
+  !$acc end parallel
+
+  !$acc parallel copy (b)
+    do i = 1, N
+      a(i) = b(i) * 2
+      b(i) = a(i)
+    end do
+  !$acc end parallel
+
+  !$acc end data
+
+end subroutine
+
+program main
+  use iso_c_binding, only: c_ptr, c_f_pointer
+  implicit none
+  type (c_ptr) :: cp
+  integer, parameter :: N = 8
+  integer, pointer :: fp(:)
+  integer :: i = 0
+  integer :: b(N)
+
+  interface
+    function acc_malloc (s) bind (C)
+      use iso_c_binding, only: c_ptr, c_size_t
+      integer (c_size_t), value :: s
+      type (c_ptr) :: acc_malloc
+    end function
+  end interface
+
+  cp = acc_malloc (N * sizeof (fp(N)))
+  call c_f_pointer (cp, fp, [N])
+
+  call subr1 (fp, b)
+
+  do i = 1, N
+    if (b(i) .ne. i * 2) call abort
+  end do
+
+  call subr2 (fp, b)
+
+  do i = 1, N
+    if (b(i) .ne. i * 4) call abort
+  end do
+
+  call subr3 (fp, b)
+
+  do i = 1, N
+    if (b(i) .ne. i * 8) call abort
+  end do
+
+  call subr4 (fp, b)
+
+  do i = 1, N
+    if (b(i) .ne. i * 16) call abort
+  end do
+
+  call subr5 (fp, b)
+
+  do i = 1, N
+    if (b(i) .ne. i * 32) call abort
+  end do
+
+  call subr6 (fp, b)
+
+  do i = 1, N
+    if (b(i) .ne. i) call abort
+  end do
+
+  call subr7 (fp, b)
+
+  do i = 1, N
+    if (b(i) .ne. i * 4) call abort
+  end do
+
+end program main