瀏覽代碼

- make constant types uniform

Olivier Aumage 9 年之前
父節點
當前提交
49a554275b
共有 2 個文件被更改,包括 69 次插入46 次删除
  1. 43 31
      include/fstarpu_mod.f90
  2. 26 15
      src/util/fstarpu.c

+ 43 - 31
include/fstarpu_mod.f90

@@ -15,12 +15,17 @@
 
 module fstarpu_mod
         use iso_c_binding
-
-        integer(c_int), bind(C) :: FSTARPU_R
-        integer(c_int), bind(C) :: FSTARPU_W
-        integer(c_int), bind(C) :: FSTARPU_RW
-        integer(c_int), bind(C) :: FSTARPU_SCRATCH
-        integer(c_int), bind(C) :: FSTARPU_REDUX
+        implicit none
+
+        ! Note: Constants truly are intptr_t, but are declared as c_ptr to be
+        ! readily usable in c_ptr arrays to mimic variadic functions.
+        ! A side effect, though, is that such constants cannot be logically
+        ! 'or'-ed.
+        type(c_ptr), bind(C) :: FSTARPU_R
+        type(c_ptr), bind(C) :: FSTARPU_W
+        type(c_ptr), bind(C) :: FSTARPU_RW
+        type(c_ptr), bind(C) :: FSTARPU_SCRATCH
+        type(c_ptr), bind(C) :: FSTARPU_REDUX
 
         type(c_ptr), bind(C) :: FSTARPU_DATA
         type(c_ptr), bind(C) :: FSTARPU_VALUE
@@ -144,9 +149,9 @@ module fstarpu_mod
                 end subroutine fstarpu_codelet_add_opencl_func
 
                 subroutine fstarpu_codelet_add_buffer (cl, mode) bind(C)
-                        use iso_c_binding, only: c_ptr, c_int
+                        use iso_c_binding, only: c_ptr, c_ptr
                         type(c_ptr), value, intent(in) :: cl
-                        integer(c_int), value, intent(in) :: mode
+                        type(c_ptr), value, intent(in) :: mode ! C function expects an intptr_t
                 end subroutine fstarpu_codelet_add_buffer
 
                 function fstarpu_vector_data_register(vector, nx, elt_size, ram) bind(C)
@@ -230,6 +235,19 @@ module fstarpu_mod
         end interface
 
         contains
+                function ip_to_p(i) bind(C)
+                        use iso_c_binding, only: c_ptr,c_intptr_t,C_NULL_PTR
+                        type(c_ptr) :: ip_to_p
+                        integer(c_intptr_t), value, intent(in) :: i
+                        ip_to_p = transfer(i,C_NULL_PTR)
+                end function ip_to_p
+
+                function sz_to_p(sz) bind(C)
+                        use iso_c_binding, only: c_ptr,c_size_t,c_intptr_t
+                        type(c_ptr) :: sz_to_p
+                        integer(c_size_t), value, intent(in) :: sz
+                        sz_to_p = ip_to_p(int(sz,kind=c_intptr_t))
+                end function sz_to_p
 
                 function fstarpu_init (conf) bind(C)
                         use iso_c_binding
@@ -248,39 +266,33 @@ module fstarpu_mod
 
                         interface
                                 ! These functions are not exported to the end user
-                                function fstarpu_get_integer_constant(s) bind(C)
-                                        use iso_c_binding, only: c_int,c_char
-                                        integer(c_int) :: fstarpu_get_integer_constant
+                                function fstarpu_get_constant(s) bind(C)
+                                        use iso_c_binding, only: c_ptr,c_char
+                                        type(c_ptr) :: fstarpu_get_constant ! C function returns an intptr_t
                                         character(kind=c_char) :: s
-                                end function fstarpu_get_integer_constant
-
-                                function fstarpu_get_pointer_constant(s) bind(C)
-                                        use iso_c_binding, only: c_intptr_t,c_char
-                                        integer(c_intptr_t) :: fstarpu_get_pointer_constant
-                                        character(kind=c_char) :: s
-                                end function fstarpu_get_pointer_constant
+                                end function fstarpu_get_constant
 
                                 function fstarpu_init_internal (conf) bind(C,name="starpu_init")
                                         use iso_c_binding, only: c_ptr,c_int
                                         integer(c_int) :: fstarpu_init_internal
                                         type(c_ptr), value :: conf
                                 end function fstarpu_init_internal
+
                         end interface
 
-                        ! Initialize Fortran integer constants from C peers
-                        FSTARPU_R = fstarpu_get_integer_constant(C_CHAR_"FSTARPU_R"//C_NULL_CHAR)
-                        FSTARPU_W = fstarpu_get_integer_constant(C_CHAR_"FSTARPU_W"//C_NULL_CHAR)
-                        FSTARPU_RW = fstarpu_get_integer_constant(C_CHAR_"FSTARPU_RW"//C_NULL_CHAR)
-                        FSTARPU_SCRATCH = fstarpu_get_integer_constant(C_CHAR_"FSTARPU_SCRATCH"//C_NULL_CHAR)
-                        FSTARPU_REDUX = fstarpu_get_integer_constant(C_CHAR_"FSTARPU_REDUX"//C_NULL_CHAR)
-                        ! Initialize Fortran 'pointer' constants from C peers
-                        FSTARPU_DATA = transfer(fstarpu_get_pointer_constant(C_CHAR_"FSTARPU_DATA"//C_NULL_CHAR),C_NULL_PTR)
-                        FSTARPU_VALUE = transfer(fstarpu_get_pointer_constant(C_CHAR_"FSTARPU_VALUE"//C_NULL_CHAR),C_NULL_PTR)
+                        ! Initialize Fortran constants from C peers
+                        FSTARPU_R       = fstarpu_get_constant(C_CHAR_"FSTARPU_R"//C_NULL_CHAR)
+                        FSTARPU_W       = fstarpu_get_constant(C_CHAR_"FSTARPU_W"//C_NULL_CHAR)
+                        FSTARPU_RW      = fstarpu_get_constant(C_CHAR_"FSTARPU_RW"//C_NULL_CHAR)
+                        FSTARPU_SCRATCH = fstarpu_get_constant(C_CHAR_"FSTARPU_SCRATCH"//C_NULL_CHAR)
+                        FSTARPU_REDUX   = fstarpu_get_constant(C_CHAR_"FSTARPU_REDUX"//C_NULL_CHAR)
+                        FSTARPU_DATA    = fstarpu_get_constant(C_CHAR_"FSTARPU_DATA"//C_NULL_CHAR)
+                        FSTARPU_VALUE   = fstarpu_get_constant(C_CHAR_"FSTARPU_VALUE"//C_NULL_CHAR)
                         ! Initialize size constants as 'c_ptr'
-                        FSTARPU_SZ_INT4 = transfer(int(c_sizeof(FSTARPU_SZ_INT4_dummy),kind=c_intptr_t),C_NULL_PTR)
-                        FSTARPU_SZ_INT8 = transfer(int(c_sizeof(FSTARPU_SZ_INT8_dummy),kind=c_intptr_t),C_NULL_PTR)
-                        FSTARPU_SZ_REAL4 = transfer(int(c_sizeof(FSTARPU_SZ_REAL4_dummy),kind=c_intptr_t),C_NULL_PTR)
-                        FSTARPU_SZ_REAL8 = transfer(int(c_sizeof(FSTARPU_SZ_REAL8_dummy),kind=c_intptr_t),C_NULL_PTR)
+                        FSTARPU_SZ_INT4         = sz_to_p(c_sizeof(FSTARPU_SZ_INT4_dummy))
+                        FSTARPU_SZ_INT8         = sz_to_p(c_sizeof(FSTARPU_SZ_INT8_dummy))
+                        FSTARPU_SZ_REAL4        = sz_to_p(c_sizeof(FSTARPU_SZ_REAL4_dummy))
+                        FSTARPU_SZ_REAL8        = sz_to_p(c_sizeof(FSTARPU_SZ_REAL8_dummy))
                         ! Initialize StarPU
                         if (c_associated(conf)) then 
                                 fstarpu_init = fstarpu_init_internal(conf)

+ 26 - 15
src/util/fstarpu.c

@@ -21,31 +21,28 @@
 
 #define _FSTARPU_ERROR(msg) do {fprintf(stderr, "fstarpu error: %s\n", (msg));abort();} while(0)
 
-static const int fstarpu_r	= STARPU_R;
-static const int fstarpu_w	= STARPU_W;
-static const int fstarpu_rw	= STARPU_RW;
-static const int fstarpu_scratch	= STARPU_SCRATCH;
-static const int fstarpu_redux	= STARPU_REDUX;
+static const intptr_t fstarpu_r	= STARPU_R;
+static const intptr_t fstarpu_w	= STARPU_W;
+static const intptr_t fstarpu_rw	= STARPU_RW;
+static const intptr_t fstarpu_scratch	= STARPU_SCRATCH;
+static const intptr_t fstarpu_redux	= STARPU_REDUX;
 
 static const intptr_t fstarpu_data = STARPU_R | STARPU_W | STARPU_SCRATCH | STARPU_REDUX;
 static const intptr_t fstarpu_value = STARPU_VALUE;
 
 extern void _starpu_pack_arguments(size_t *current_offset, size_t *arg_buffer_size_, char **arg_buffer_, void *ptr, size_t ptr_size);
 
-int fstarpu_get_integer_constant(char *s)
+intptr_t fstarpu_get_constant(char *s)
 {
 	if	(!strcmp(s, "FSTARPU_R"))	{ return fstarpu_r; }
 	else if	(!strcmp(s, "FSTARPU_W"))	{ return fstarpu_w; }
 	else if	(!strcmp(s, "FSTARPU_RW"))	{ return fstarpu_rw; }
 	else if	(!strcmp(s, "FSTARPU_SCRATCH"))	{ return fstarpu_scratch; }
 	else if	(!strcmp(s, "FSTARPU_REDUX"))	{ return fstarpu_redux; }
-	else { _FSTARPU_ERROR("unknown integer constant"); }
-}
 
-intptr_t fstarpu_get_pointer_constant(char *s)
-{
-	if (!strcmp(s, "FSTARPU_DATA")) { return fstarpu_data; }
-	if (!strcmp(s, "FSTARPU_VALUE")) { return fstarpu_value; }
+	else if (!strcmp(s, "FSTARPU_DATA"))	{ return fstarpu_data; }
+	else if (!strcmp(s, "FSTARPU_VALUE"))	{ return fstarpu_value; }
+
 	else { _FSTARPU_ERROR("unknown pointer constant"); }
 }
 
@@ -107,9 +104,17 @@ void fstarpu_codelet_add_opencl_func(struct starpu_codelet *cl, void *f_ptr)
 	_FSTARPU_ERROR("fstarpu: too many opencl functions in Fortran codelet");
 }
 
-void fstarpu_codelet_add_buffer(struct starpu_codelet *cl, int mode)
+void fstarpu_codelet_add_buffer(struct starpu_codelet *cl, intptr_t mode)
 {
 	const size_t max_modes = sizeof(cl->modes)/sizeof(cl->modes[0])-1;
+	if (mode !=  fstarpu_r
+		&& mode != fstarpu_rw
+		&& mode != fstarpu_w
+		&& mode != fstarpu_scratch
+		&& mode != fstarpu_redux)
+	{
+		_FSTARPU_ERROR("fstarpu: invalid data mode");
+	}
 	if  (cl->nbuffers < max_modes)
 	{
 		cl->modes[cl->nbuffers] = (unsigned int)mode;
@@ -223,7 +228,13 @@ void fstarpu_insert_task(void ***_arglist)
 	task->name = NULL;
 	while (arglist[i] != NULL)
 	{
-		if ((intptr_t)arglist[i] == fstarpu_data)
+		const intptr_t arg_type = (intptr_t)arglist[i];
+		if (arg_type == fstarpu_data
+			|| arg_type == fstarpu_r
+			|| arg_type == fstarpu_rw
+			|| arg_type == fstarpu_w
+			|| arg_type == fstarpu_scratch
+			|| arg_type == fstarpu_redux)
 		{
 			i++;
 			starpu_data_handle_t handle = arglist[i];
@@ -238,7 +249,7 @@ void fstarpu_insert_task(void ***_arglist)
 			}
 			current_buffer++;
 		}
-		else if ((intptr_t)arglist[i] == fstarpu_value)
+		else if (arg_type == fstarpu_value)
 		{
 			i++;
 			void *ptr = arglist[i];