Переглянути джерело

fix function struct starpu_data_interface_ops *_starpu_data_interface_get_ops(unsigned interface_id)
to return the ops also for user-registered datatypes
This allows to use starpu_data_cpy() for such datatypes

Nathalie Furmento 4 роки тому
батько
коміт
60dc3ab232

+ 42 - 0
examples/interface/complex.c

@@ -18,6 +18,25 @@
 #include "complex_interface.h"
 #include "complex_codelet.h"
 
+void copy_complex_codelet_cpu(void *descr[], void *_args)
+{
+	int i;
+	int nx = STARPU_COMPLEX_GET_NX(descr[0]);
+
+	double *i_real = STARPU_COMPLEX_GET_REAL(descr[0]);
+	double *i_imaginary = STARPU_COMPLEX_GET_IMAGINARY(descr[0]);
+
+	double *o_real = STARPU_COMPLEX_GET_REAL(descr[1]);
+	double *o_imaginary = STARPU_COMPLEX_GET_IMAGINARY(descr[1]);
+
+	for(i=0 ; i<nx ; i++)
+	{
+		o_real[i] = i_real[i];
+		o_imaginary[i] = i_imaginary[i];
+	}
+
+}
+
 static int can_execute(unsigned workerid, struct starpu_task *task, unsigned nimpl)
 {
 	(void) task;
@@ -58,6 +77,7 @@ extern void copy_complex_codelet_opencl(void *buffers[], void *args);
 
 struct starpu_codelet cl_copy =
 {
+	.cpu_funcs = {copy_complex_codelet_cpu},
 #ifdef STARPU_USE_CUDA
 	.cuda_funcs = {copy_complex_codelet_cuda},
 	.cuda_flags = {STARPU_CUDA_ASYNC},
@@ -82,6 +102,7 @@ int main(void)
 	starpu_data_handle_t handle1;
 	starpu_data_handle_t handle2;
 	starpu_data_handle_t handle3;
+	starpu_data_handle_t handle4;
 
 	double real = 45.0;
 	double imaginary = 12.0;
@@ -227,6 +248,27 @@ int main(void)
 
 	starpu_data_unpartition(handle3, STARPU_MAIN_RAM);
 
+	/* Use helper starpu_data_cpy */
+	starpu_complex_data_register(&handle4, -1, 0, 0, 1);
+	starpu_data_cpy(handle4, handle1, 0, NULL, NULL);
+	ret = starpu_task_insert(&cl_display, STARPU_VALUE, "handle4", strlen("handle4")+1, STARPU_R, handle4, 0);
+	if (ret == -ENODEV) goto end;
+	STARPU_CHECK_RETURN_VALUE(ret, "starpu_task_insert");
+	/* Compare two different complexs.  */
+	ret = starpu_task_insert(&cl_compare,
+				 STARPU_R, handle1,
+				 STARPU_R, handle4,
+				 STARPU_VALUE, &compare_ptr, sizeof(compare_ptr),
+				 0);
+	if (ret == -ENODEV) goto end;
+	STARPU_CHECK_RETURN_VALUE(ret, "starpu_task_insert");
+	starpu_task_wait_for_all();
+	if (compare != 1)
+	{
+	     FPRINTF(stderr, "Complex numbers should be similar\n");
+	     goto end;
+	}
+
 end:
 #ifdef STARPU_USE_OPENCL
 	{

+ 3 - 1
examples/interface/complex_codelet.h

@@ -87,10 +87,12 @@ void display_complex_codelet(void *descr[], void *_args)
 	if (_args)
 		starpu_codelet_unpack_args(_args, &msg);
 
+	FPRINTF(stderr, "[%s]\n", _args?msg:NULL);
 	for(i=0 ; i<nx ; i++)
 	{
-		FPRINTF(stderr, "[%s] Complex[%d] = %3.2f + %3.2f i\n", _args?msg:NULL, i, real[i], imaginary[i]);
+		FPRINTF(stderr, "\tComplex[%d] = %3.2f + %3.2f i\n", i, real[i], imaginary[i]);
 	}
+	fflush(stderr);
 }
 
 struct starpu_codelet cl_display =

+ 26 - 2
src/datawizard/interfaces/data_interface.c

@@ -30,6 +30,9 @@
 #include <util/openmp_runtime_support.h>
 #endif
 
+static struct starpu_data_interface_ops **_id_to_ops_array;
+static unsigned _id_to_ops_array_size;
+
 /* Entry in the `registered_handles' hash table.  */
 struct handle_entry
 {
@@ -50,6 +53,8 @@ static void _starpu_data_unregister(starpu_data_handle_t handle, unsigned cohere
 void _starpu_data_interface_init(void)
 {
 	_starpu_spin_init(&registered_handles_lock);
+	_id_to_ops_array_size = 20;
+	_STARPU_MALLOC(_id_to_ops_array, _id_to_ops_array_size * sizeof(struct starpu_data_interface_ops *));
 
 	/* Just for testing purpose */
 	if (starpu_get_env_number_default("STARPU_GLOBAL_ARBITER", 0) > 0)
@@ -66,6 +71,7 @@ void _starpu_data_interface_shutdown()
 	}
 
 	_starpu_spin_destroy(&registered_handles_lock);
+	free(_id_to_ops_array);
 
 	HASH_ITER(hh, registered_handles, entry, tmp)
 	{
@@ -138,8 +144,16 @@ struct starpu_data_interface_ops *_starpu_data_interface_get_ops(unsigned interf
 			return &starpu_interface_multiformat_ops;
 
 		default:
-			STARPU_ABORT();
-			return NULL;
+		{
+			if (interface_id-STARPU_MAX_INTERFACE_ID > _id_to_ops_array_size || _id_to_ops_array[interface_id-STARPU_MAX_INTERFACE_ID]==NULL)
+			{
+				_STARPU_MSG("There is no 'struct starpu_data_interface_ops' registered for interface %d\n", interface_id);
+				STARPU_ABORT();
+				return NULL;
+			}
+			else
+				return _id_to_ops_array[interface_id-STARPU_MAX_INTERFACE_ID];
+		}
 	}
 }
 
@@ -555,6 +569,16 @@ void starpu_data_register(starpu_data_handle_t *handleptr, int home_node,
 	STARPU_ASSERT(ops->register_data_handle);
 	ops->register_data_handle(handle, home_node, data_interface);
 
+	if ((unsigned)ops->interfaceid >= STARPU_MAX_INTERFACE_ID)
+	{
+		if ((unsigned)ops->interfaceid > _id_to_ops_array_size)
+		{
+			_id_to_ops_array_size *= 2;
+			_STARPU_REALLOC(_id_to_ops_array, _id_to_ops_array_size * sizeof(struct starpu_data_interface_ops *));
+		}
+		_id_to_ops_array[ops->interfaceid-STARPU_MAX_INTERFACE_ID] = ops;
+	}
+
 	_starpu_register_new_data(handle, home_node, 0);
 	_STARPU_TRACE_HANDLE_DATA_REGISTER(handle);
 }

+ 1 - 0
tests/Makefile.am

@@ -276,6 +276,7 @@ myPROGRAMS +=				\
 	datawizard/copy				\
 	datawizard/data_implicit_deps		\
 	datawizard/data_lookup			\
+	datawizard/data_register		\
 	datawizard/scratch			\
 	datawizard/scratch_reuse		\
 	datawizard/sync_and_notify_data		\

+ 99 - 0
tests/datawizard/data_register.c

@@ -0,0 +1,99 @@
+/* StarPU --- Runtime system for heterogeneous multicore architectures.
+ *
+ * Copyright (C) 2020       Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
+ *
+ * StarPU is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation; either version 2.1 of the License, or (at
+ * your option) any later version.
+ *
+ * StarPU is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ *
+ * See the GNU Lesser General Public License in COPYING.LGPL for more details.
+ */
+
+#include <starpu.h>
+#include "../helper.h"
+#include <datawizard/interfaces/data_interface.h>
+
+struct my_interface
+{
+	enum starpu_data_interface_id id;
+	/* Just a integer */
+	int x;
+};
+
+static struct starpu_data_interface_ops starpu_interface_my_ops;
+
+static void register_my(starpu_data_handle_t handle, unsigned home_node, void *data_interface)
+{
+	struct my_interface *my_interface = data_interface;
+	unsigned node;
+	for (node = 0; node < STARPU_MAXNODES; node++)
+	{
+		struct my_interface *local_interface = starpu_data_get_interface_on_node(handle, node);
+		local_interface->x = my_interface->x;
+		local_interface->id = my_interface->id;
+	}
+}
+
+static size_t my_get_size(starpu_data_handle_t handle)
+{
+	struct my_interface *interface = starpu_data_get_interface_on_node(handle, STARPU_MAIN_RAM);
+	return interface->x;
+}
+
+static uint32_t my_footprint(starpu_data_handle_t handle)
+{
+	return starpu_hash_crc32c_be(my_get_size(handle), 0);
+}
+
+static struct starpu_data_interface_ops starpu_interface_my_ops =
+{
+	.register_data_handle = register_my,
+	.allocate_data_on_node = NULL,
+	.free_data_on_node = NULL,
+	.copy_methods = NULL,
+	.get_size = my_get_size,
+	.get_max_size = NULL,
+	.footprint = my_footprint,
+	.compare = NULL,
+	.interfaceid = STARPU_UNKNOWN_INTERFACE_ID,
+	.interface_size = sizeof(struct my_interface),
+	.display = NULL,
+	.pack_data = NULL,
+	.unpack_data = NULL,
+	.describe = NULL,
+};
+
+#define N 42
+int main(void)
+{
+	int ret;
+	int x;
+	starpu_data_handle_t handles[N];
+
+	ret = starpu_init(NULL);
+	if (ret == -ENODEV) return STARPU_TEST_SKIPPED;
+	STARPU_CHECK_RETURN_VALUE(ret, "starpu_init");
+
+	for (x = 0; x < N; x++)
+	{
+		starpu_interface_my_ops.interfaceid = starpu_data_interface_get_next_id();
+		struct my_interface interface =
+		{
+			.id = starpu_interface_my_ops.interfaceid,
+		};
+		starpu_data_register(&handles[x], -1, &interface, &starpu_interface_my_ops);
+		STARPU_ASSERT(_starpu_data_interface_get_ops(interface.id) == &starpu_interface_my_ops);
+	}
+
+	for (x = 0; x < N; x++)
+		starpu_data_unregister(handles[x]);
+
+	starpu_shutdown();
+
+	return EXIT_SUCCESS;
+}