Explorar el Código

Add function get_kernel_from_job to mpi src driver : not tested yet

Corentin Salingue hace 8 años
padre
commit
15d23e5205

+ 1 - 0
Makefile.am

@@ -75,6 +75,7 @@ versinclude_HEADERS = 				\
 	include/starpu_openmp.h			\
 	include/starpu_sink.h			\
 	include/starpu_mic.h			\
+	include/starpu_mpi_ms.h			\
 	include/starpu_scc.h			\
 	include/starpu_expert.h			\
 	include/starpu_profiling.h		\

+ 40 - 0
include/starpu_mpi_ms.h

@@ -0,0 +1,40 @@
+/* StarPU --- Runtime system for heterogeneous multicore architectures.
+ *
+ * Copyright (C) 2016  INRIA
+ *
+ * StarPU is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation; either version 2.1 of the License, or (at
+ * your option) any later version.
+ *
+ * StarPU is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ *
+ * See the GNU Lesser General Public License in COPYING.LGPL for more details.
+ */
+
+#ifndef __STARPU_MPI_MS_H__
+#define __STARPU_MPI_MS_H__
+
+#include <starpu_config.h>
+
+#ifdef STARPU_USE_MPI_MASTER_SLAVE
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+
+typedef void *starpu_mpi_ms_func_symbol_t;
+
+int starpu_mpi_ms_register_kernel(starpu_mpi_ms_func_symbol_t *symbol, const char *func_name);
+
+starpu_mpi_ms_kernel_t starpu_mpi_ms_get_kernel(starpu_mpi_ms_func_symbol_t symbol);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* STARPU_USE_MPI_MASTER_SLAVE */
+#endif /* __STARPU_MPI_MS_H__ */

+ 3 - 0
include/starpu_task.h

@@ -75,9 +75,11 @@ typedef void (*starpu_cpu_func_t)(void **, void*);
 typedef void (*starpu_cuda_func_t)(void **, void*);
 typedef void (*starpu_opencl_func_t)(void **, void*);
 typedef void (*starpu_mic_kernel_t)(void **, void*);
+typedef void (*starpu_mpi_ms_kernel_t)(void **, void*);
 typedef void (*starpu_scc_kernel_t)(void **, void*);
 
 typedef starpu_mic_kernel_t (*starpu_mic_func_t)(void);
+typedef starpu_mpi_ms_kernel_t (*starpu_mpi_ms_func_t)(void);
 typedef starpu_scc_kernel_t (*starpu_scc_func_t)(void);
 
 #define STARPU_MULTIPLE_CPU_IMPLEMENTATIONS    ((starpu_cpu_func_t) -1)
@@ -104,6 +106,7 @@ struct starpu_codelet
 	starpu_opencl_func_t opencl_funcs[STARPU_MAXIMPLEMENTATIONS];
 	char opencl_flags[STARPU_MAXIMPLEMENTATIONS];
 	starpu_mic_func_t mic_funcs[STARPU_MAXIMPLEMENTATIONS];
+	starpu_mpi_ms_func_t mpi_ms_funcs[STARPU_MAXIMPLEMENTATIONS];
 	starpu_scc_func_t scc_funcs[STARPU_MAXIMPLEMENTATIONS];
 
 	const char *cpu_funcs_name[STARPU_MAXIMPLEMENTATIONS];

+ 5 - 0
src/core/task.h

@@ -111,6 +111,11 @@ static inline starpu_mic_func_t _starpu_task_get_mic_nth_implementation(struct s
 	return cl->mic_funcs[nimpl];
 }
 
+/* Return entry nimpl of the codelet's mpi_ms_funcs table. No bounds check is performed on nimpl. */
+static inline starpu_mpi_ms_func_t _starpu_task_get_mpi_ms_nth_implementation(struct starpu_codelet *cl, unsigned nimpl)
+{
+	starpu_mpi_ms_func_t *implementations = cl->mpi_ms_funcs;
+
+	return implementations[nimpl];
+}
+
 static inline starpu_scc_func_t _starpu_task_get_scc_nth_implementation(struct starpu_codelet *cl, unsigned nimpl)
 {
 	return cl->scc_funcs[nimpl];

+ 5 - 7
src/drivers/mp_common/mp_common.c

@@ -244,14 +244,12 @@ _starpu_mp_common_node_create(enum _starpu_mp_node_kind node_kind,
         node->dt_send_to_device = _starpu_mpi_common_send_to_device;
         node->dt_recv_from_device = _starpu_mpi_common_recv_from_device;
 
-/*		node->get_kernel_from_job = 
-		node->lookup = 
+		node->get_kernel_from_job = _starpu_mpi_ms_src_get_kernel_from_job;
+/*		node->lookup = 
 */		node->bind_thread = NULL;
-/*		node->execute = 
-		node->allocate = 
-		node->free = 
-
-        */
+		node->execute = NULL;
+		node->allocate = NULL;
+		node->free = NULL;
     }
 	break;
 

+ 126 - 0
src/drivers/mpi/driver_mpi_source.c

@@ -26,7 +26,25 @@
 #include <drivers/driver_common/driver_common.h>
 #include <drivers/mp_common/source_common.h>
 
+/* Mutex for concurrent access to the table.
+ */
+starpu_pthread_mutex_t htbl_mutex = STARPU_PTHREAD_MUTEX_INITIALIZER;
+
+/* Structure used by the host to store information about a kernel executable on
+ * an MPI MS device: its name, and its address on each device.
+ * If a kernel has been initialized, then a lookup has already been performed and the
+ * device knows how to call it; otherwise the host still needs to do a lookup.
+ */
+struct _starpu_mpi_ms_kernel
+{
+	UT_hash_handle hh;	/* uthash bookkeeping handle */
+	char *name;	/* kernel name; string key of the hash table */
+	starpu_mpi_ms_kernel_t func[STARPU_MAXMPIDEVS];	/* kernel address per device id, NULL until looked up */
+} *kernels;	/* head of the hash table of registered kernels */
+
 
+/* Array of structures containing all the information needed to
+ * exchange messages and data with the devices */
 struct _starpu_mp_node *mpi_ms_nodes[STARPU_MAXMPIDEVS];
 
 void _starpu_mpi_source_init(struct _starpu_mp_node *node)
@@ -40,6 +58,114 @@ void _starpu_mpi_source_deinit(struct _starpu_mp_node *node)
 
 }
 
+/* Register the kernel named func_name in the host-side hash table of MPI
+ * master-slave kernels and store an opaque handle to its entry in *symbol.
+ *
+ * If a kernel with the same name is already registered, its existing entry is
+ * returned and nothing is allocated. Otherwise a new entry is created with a
+ * private copy of the name and every per-device address slot set to NULL.
+ *
+ * Returns 0 on success, -EINVAL if symbol or func_name is NULL, and -ENOMEM
+ * if memory could not be allocated. *symbol is only written on success.
+ */
+int _starpu_mpi_ms_src_register_kernel(starpu_mpi_ms_func_symbol_t *symbol, const char *func_name)
+{
+	struct _starpu_mpi_ms_kernel *kernel;
+	size_t func_name_size;
+	unsigned int i;
+	int ret = 0;
+
+	if (symbol == NULL || func_name == NULL)
+		return -EINVAL;
+
+	/* Size of the name including its terminating NUL byte. */
+	func_name_size = strlen(func_name) + 1;
+
+	STARPU_PTHREAD_MUTEX_LOCK(&htbl_mutex);
+
+	HASH_FIND_STR(kernels, func_name, kernel);
+	if (kernel == NULL)
+	{
+		/* Function not in the table yet: build a complete entry first. */
+		kernel = malloc(sizeof(*kernel));
+		if (kernel == NULL)
+		{
+			ret = -ENOMEM;
+			goto out;
+		}
+
+		kernel->name = malloc(func_name_size);
+		if (kernel->name == NULL)
+		{
+			free(kernel);
+			kernel = NULL;
+			ret = -ENOMEM;
+			goto out;
+		}
+		memcpy(kernel->name, func_name, func_name_size);
+
+		/* Clear every slot, not only those of the devices currently
+		 * present, so that no element of func[] is left indeterminate. */
+		for (i = 0; i < STARPU_MAXMPIDEVS; ++i)
+			kernel->func[i] = NULL;
+
+		/* Publish the entry only once it is fully initialized. */
+		HASH_ADD_STR(kernels, name, kernel);
+	}
+
+out:
+	STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
+
+	if (ret == 0)
+		*symbol = kernel;
+
+	return ret;
+}
+
+
+/* Return the address of the kernel designated by symbol for the device
+ * driven by the calling worker.
+ *
+ * symbol is a handle obtained from _starpu_mpi_ms_src_register_kernel().
+ * If the slot for the worker's device is still NULL,
+ * _starpu_src_common_lookup() is called with the device's node and the kernel
+ * name to fill it; the stored value is reused on later calls.
+ *
+ * Returns NULL if symbol is NULL, if the caller is not a worker, if the
+ * device id is outside [0, STARPU_MAXMPIDEVS), or if the lookup fails.
+ */
+starpu_mpi_ms_kernel_t _starpu_mpi_ms_src_get_kernel(starpu_mpi_ms_func_symbol_t symbol)
+{
+	struct _starpu_mpi_ms_kernel *kernel = symbol;
+	int workerid;
+	int devid;
+
+	if (kernel == NULL)
+		return NULL;
+
+	/* This function has to be called in the codelet only, by the thread
+	 * which will handle the task */
+	workerid = starpu_worker_get_id();
+	if (workerid < 0)
+		return NULL;
+
+	/* devid indexes both kernel->func[] and mpi_ms_nodes[], which hold
+	 * STARPU_MAXMPIDEVS entries each. */
+	devid = starpu_worker_get_devid(workerid);
+	if (devid < 0 || devid >= STARPU_MAXMPIDEVS)
+		return NULL;
+
+	if (kernel->func[devid] == NULL)
+	{
+		struct _starpu_mp_node *node = mpi_ms_nodes[devid];
+		int ret = _starpu_src_common_lookup(node, (void (**)(void))&kernel->func[devid], kernel->name);
+		if (ret)
+			return NULL;
+	}
+
+	return kernel->func[devid];
+}
+
+/* Return the device-side kernel to run for job j, as a generic function
+ * pointer.
+ *
+ * If the codelet provides an MPI master-slave function for the selected
+ * implementation (cl->mpi_ms_funcs[j->nimpl]), it is called and its return
+ * value is used. Otherwise the CPU function name of that implementation, if
+ * any, is registered with _starpu_mpi_ms_src_register_kernel() and resolved
+ * with _starpu_mpi_ms_src_get_kernel().
+ *
+ * node is unused. The function asserts (STARPU_ASSERT) that a kernel was
+ * found.
+ */
+void(* _starpu_mpi_ms_src_get_kernel_from_job(const struct _starpu_mp_node *node STARPU_ATTRIBUTE_UNUSED, struct _starpu_job *j))(void)
+{
+	starpu_mpi_ms_kernel_t kernel = NULL;
+
+	starpu_mpi_ms_func_t func = _starpu_task_get_mpi_ms_nth_implementation(j->task->cl, j->nimpl);
+	if (func)
+	{
+		/* We execute the function contained in the codelet, it must return a
+		 * pointer to the function to execute on the device, either specified
+		 * directly by the user or by a call to starpu_mpi_ms_get_kernel().
+		 */
+		kernel = func();
+	}
+	else
+	{
+		/* If the user did not define any starpu_mpi_ms_func_t in
+		 * cl->mpi_ms_funcs, we try to use cpu_funcs_name.
+		 */
+		const char *func_name = _starpu_task_get_cpu_name_nth_implementation(j->task->cl, j->nimpl);
+		if (func_name)
+		{
+			starpu_mpi_ms_func_symbol_t symbol = NULL;
+
+			/* Only resolve the symbol when registration succeeded:
+			 * on failure it does not hold a valid table entry. */
+			if (_starpu_mpi_ms_src_register_kernel(&symbol, func_name) == 0)
+				kernel = _starpu_mpi_ms_src_get_kernel(symbol);
+		}
+	}
+	STARPU_ASSERT(kernel);
+
+	return (void (*)(void))kernel;
+}
+
+
+
 unsigned _starpu_mpi_src_get_device_count()
 {
     int nb_mpi_devices;

+ 3 - 0
src/drivers/mpi/driver_mpi_source.h

@@ -19,6 +19,7 @@
 #define __DRIVER_MPI_SOURCE_H__
 
 #include <drivers/mp_common/mp_common.h>
+#include <starpu_mpi_ms.h>
 
 #ifdef STARPU_USE_MPI_MASTER_SLAVE
 
@@ -33,6 +34,8 @@ void _starpu_mpi_exit_useless_node(int devid);
 void _starpu_mpi_source_init(struct _starpu_mp_node *node);
 void _starpu_mpi_source_deinit(struct _starpu_mp_node *node);
 
+void(* _starpu_mpi_ms_src_get_kernel_from_job(const struct _starpu_mp_node *node STARPU_ATTRIBUTE_UNUSED, struct _starpu_job *j))(void);
+
 ///* Send *MSG which can be a command or data, to a MPI sink. */
 //extern void _starpu_mpi_source_send(const struct _starpu_mp_node *node,
 //				    void *msg, int len);