Ver código fonte

add Master <-> Slaves transfers + fix bug when asynchronous messages are received + fix bug when multiple thread is not activated

Corentin Salingue 8 anos atrás
pai
commit
4a14747dd9

+ 15 - 1
src/core/task.c

@@ -485,6 +485,18 @@ void _starpu_codelet_check_deprecated_fields(struct starpu_codelet *cl)
 
 	some_impl = 0;
 	for (i = 0; i < STARPU_MAXIMPLEMENTATIONS; i++)
+		if (cl->mpi_ms_funcs[i])
+		{
+			some_impl = 1;
+			break;
+		}
+	if (some_impl && is_where_unset)
+	{
+		cl->where |= STARPU_MPI;
+	}
+
+	some_impl = 0;
+	for (i = 0; i < STARPU_MAXIMPLEMENTATIONS; i++)
 		if (cl->scc_funcs[i])
 		{
 			some_impl = 1;
@@ -504,7 +516,7 @@ void _starpu_codelet_check_deprecated_fields(struct starpu_codelet *cl)
 		}
 	if (some_impl && is_where_unset)
 	{
-		cl->where |= STARPU_MIC|STARPU_SCC;
+		cl->where |= STARPU_MIC|STARPU_SCC|STARPU_MPI;
 	}
 }
 
@@ -1137,6 +1149,7 @@ _starpu_handle_needs_conversion_task_for_arch(starpu_data_handle_t handle,
 				case STARPU_CUDA_RAM:      /* Fall through */
 				case STARPU_OPENCL_RAM:
 				case STARPU_MIC_RAM:
+                case STARPU_MPI_MS_RAM:
 				case STARPU_SCC_RAM:
 					return 1;
 				default:
@@ -1154,6 +1167,7 @@ _starpu_handle_needs_conversion_task_for_arch(starpu_data_handle_t handle,
 				case STARPU_CUDA_RAM:
 				case STARPU_OPENCL_RAM:
 				case STARPU_MIC_RAM:
+                case STARPU_MPI_MS_RAM:
 				case STARPU_SCC_RAM:
 					return 0;
 				default:

+ 1 - 2
src/core/workers.c

@@ -605,7 +605,6 @@ void _starpu_driver_start(struct _starpu_worker *worker, unsigned fut_key, unsig
 	_starpu_fxt_register_thread(worker->bindid);
 	_starpu_worker_start(worker, fut_key, sync);
 #endif
-
 	_starpu_memory_node_set_local_key(&worker->memory_node);
 
 	_starpu_set_local_worker_key(worker);
@@ -1193,7 +1192,7 @@ int starpu_initialize(struct starpu_conf *user_conf, int *argc, char ***argv)
 #   ifdef STARPU_USE_MPI_MASTER_SLAVE
 	/* In MPI case we look at the rank to know if we are a sink */
 	if (_starpu_mpi_common_mp_init() && !_starpu_mpi_common_is_src_node())
-		setenv("STARPU_SINK", "STARPU_MPI", 1);
+		setenv("STARPU_SINK", "STARPU_MPI_MS", 1);
 #   endif
 
 	/* If StarPU was configured to use MP sinks, we have to control the

+ 5 - 1
src/datawizard/coherency.c

@@ -146,7 +146,8 @@ int _starpu_select_src_node(starpu_data_handle_t handle, unsigned destination)
 
 			if (starpu_node_get_kind(i) == STARPU_CPU_RAM || 
 			    starpu_node_get_kind(i) == STARPU_SCC_RAM ||
-			    starpu_node_get_kind(i) == STARPU_SCC_SHM)
+			    starpu_node_get_kind(i) == STARPU_SCC_SHM ||
+                starpu_node_get_kind(i) == STARPU_MPI_MS_RAM)
 				i_ram = i;
 			if (starpu_node_get_kind(i) == STARPU_DISK_RAM)			
 				i_disk = i;
@@ -263,6 +264,9 @@ static int worker_supports_direct_access(unsigned node, unsigned handling_node)
 		case STARPU_MIC_RAM:
 			/* TODO: We don't handle direct MIC-MIC transfers yet */
 			return 0;
+        case STARPU_MPI_MS_RAM:
+            /* Don't support MPI-MPI transfers yet */
+            return starpu_node_get_kind(handling_node) != STARPU_MPI_MS_RAM;
 		case STARPU_SCC_RAM:
 			return 1;
 		default:

+ 48 - 0
src/datawizard/copy_driver.c

@@ -22,6 +22,8 @@
 #include <datawizard/datastats.h>
 #include <datawizard/memory_nodes.h>
 #include <drivers/disk/driver_disk.h>
+#include <drivers/mpi/driver_mpi_sink.h>
+#include <drivers/mpi/driver_mpi_source.h>
 #include <common/fxt.h>
 #include "copy_driver.h"
 #include "memalloc.h"
@@ -420,6 +422,28 @@ static int copy_data_1_to_1_generic(starpu_data_handle_t handle,
 		break;
 	/* TODO: MIC -> MIC */
 #endif
+#ifdef STARPU_USE_MPI_MASTER_SLAVE
+	case _STARPU_MEMORY_NODE_TUPLE(STARPU_CPU_RAM,STARPU_MPI_MS_RAM):
+	//	if (copy_methods->mpi_ms_src_to_sink)
+	//		copy_methods->mpi_ms_src_to_sink(src_interface, src_node, dst_interface, dst_node);
+	//	else
+			copy_methods->any_to_any(src_interface, src_node, dst_interface, dst_node, NULL);
+		break;
+
+	case _STARPU_MEMORY_NODE_TUPLE(STARPU_MPI_MS_RAM,STARPU_CPU_RAM):
+	//	if (copy_methods->mpi_ms_sink_to_src)
+	//		copy_methods->mpi_ms_sink_to_src(src_interface, src_node, dst_interface, dst_node);
+	//	else
+			copy_methods->any_to_any(src_interface, src_node, dst_interface, dst_node, NULL);
+		break;
+
+	case _STARPU_MEMORY_NODE_TUPLE(STARPU_MPI_MS_RAM,STARPU_MPI_MS_RAM):
+	//	if (copy_methods->mpi_ms_sink_to_sink)
+	//		copy_methods->mpi_ms_sink_to_sink(src_interface, src_node, dst_interface, dst_node);
+	//	else
+			copy_methods->any_to_any(src_interface, src_node, dst_interface, dst_node, NULL);
+		break;
+#endif
 #ifdef STARPU_USE_SCC
 		/* SCC RAM associated to the master process is considered as
 		 * the main memory node. */
@@ -662,6 +686,30 @@ int starpu_interface_copy(uintptr_t src, size_t src_offset, unsigned src_node, u
 				(void*) (dst + dst_offset), dst_node,
 				size);
 #endif
+#ifdef STARPU_USE_MPI_MASTER_SLAVE
+    case _STARPU_MEMORY_NODE_TUPLE(STARPU_CPU_RAM, STARPU_MPI_MS_RAM):
+        /* TODO ASYNC */
+        return _starpu_mpi_copy_ram_to_mpi(
+                (void*) (src + src_offset), src_node,
+                (void*) (dst + dst_offset), dst_node,
+                size);
+
+    case _STARPU_MEMORY_NODE_TUPLE(STARPU_MPI_MS_RAM, STARPU_CPU_RAM):
+        /* TODO ASYNC */
+        return _starpu_mpi_copy_mpi_to_ram(
+                (void*) (src + src_offset), src_node,
+                (void*) (dst + dst_offset), dst_node,
+                size);
+
+    case _STARPU_MEMORY_NODE_TUPLE(STARPU_MPI_MS_RAM, STARPU_MPI_MS_RAM):
+        /* TODO : not used now + ASYNC */
+        STARPU_ABORT();
+        return _starpu_mpi_copy_sink_to_sink(
+                (void*) (src + src_offset), src_node,
+                (void*) (dst + dst_offset), dst_node,
+                size);
+#endif
+
 	case _STARPU_MEMORY_NODE_TUPLE(STARPU_CPU_RAM, STARPU_DISK_RAM):
 	{
 		return _starpu_disk_copy_src_to_disk(

+ 11 - 0
src/datawizard/malloc.c

@@ -591,6 +591,12 @@ _starpu_malloc_on_node(unsigned dst_node, size_t size, int flags)
 				addr = 0;
 			break;
 #endif
+#ifdef STARPU_USE_MPI_MASTER_SLAVE
+		case STARPU_MPI_MS_RAM:
+			if (_starpu_mpi_src_allocate_memory((void **)(&addr), size, dst_node))
+				addr = 0;
+			break;
+#endif
 #ifdef STARPU_USE_SCC
 		case STARPU_SCC_RAM:
 			if (_starpu_scc_allocate_memory((void **)(&addr), size, dst_node))
@@ -692,6 +698,11 @@ _starpu_free_on_node_flags(unsigned dst_node, uintptr_t addr, size_t size, int f
 			_starpu_mic_free_memory((void*) addr, size, dst_node);
 			break;
 #endif
+#ifdef STARPU_USE_MPI_MASTER_SLAVE
+        case STARPU_MPI_MS_RAM:
+            _starpu_mpi_source_free_memory((void*) addr, dst_node);
+            break;
+#endif
 #ifdef STARPU_USE_SCC
 		case STARPU_SCC_RAM:
 			_starpu_scc_free_memory((void *) addr, dst_node);

+ 3 - 0
src/datawizard/memory_nodes.c

@@ -92,6 +92,9 @@ void _starpu_memory_node_get_name(unsigned node, char *name, int size)
 	case STARPU_MIC_RAM:
 		prefix = "MIC";
 		break;
+	case STARPU_MPI_MS_RAM:
+		prefix = "MPI_MS";
+		break;
 	case STARPU_SCC_RAM:
 		prefix = "SCC_RAM";
 		break;

+ 2 - 0
src/drivers/mp_common/mp_common.c

@@ -271,6 +271,8 @@ _starpu_mp_common_node_create(enum _starpu_mp_node_kind node_kind,
 		node->mp_recv = _starpu_mpi_common_recv;
 		node->dt_send = _starpu_mpi_common_send;
 		node->dt_recv = _starpu_mpi_common_recv;
+        node->dt_send_to_device = _starpu_mpi_common_send_to_device;
+        node->dt_recv_from_device = _starpu_mpi_common_recv_from_device;
 
 		node->get_kernel_from_job = NULL;
 		node->lookup = _starpu_mpi_sink_lookup;

+ 7 - 2
src/drivers/mp_common/sink_common.c

@@ -45,7 +45,7 @@ static enum _starpu_mp_node_kind _starpu_sink_common_get_kind(void)
 		return STARPU_MIC_SINK;
 	else if (!strcmp(node_kind, "STARPU_SCC"))
 		return STARPU_SCC_SINK;
-	else if (!strcmp(node_kind, "STARPU_MPI"))
+	else if (!strcmp(node_kind, "STARPU_MPI_MS"))
 		return STARPU_MPI_SINK;
 	else
 		return STARPU_INVALID_KIND;
@@ -123,8 +123,13 @@ static void _starpu_sink_common_copy_to_host(const struct _starpu_mp_node *mp_no
 	STARPU_ASSERT(arg_size == sizeof(struct _starpu_mp_transfer_command));
 
 	struct _starpu_mp_transfer_command *cmd = (struct _starpu_mp_transfer_command *)arg;
+    /* Save values before sending command to prevent the overwriting */
+    size_t size = cmd->size;
+    void * addr = cmd->addr;
 
-	mp_node->dt_send(mp_node, cmd->addr, cmd->size);
+	_starpu_mp_common_send_command(mp_node, STARPU_SEND_TO_HOST, NULL, 0);
+    
+	mp_node->dt_send(mp_node, addr, size);
 }
 
 static void _starpu_sink_common_copy_from_sink(const struct _starpu_mp_node *mp_node,

+ 41 - 2
src/drivers/mp_common/source_common.c

@@ -35,6 +35,10 @@ struct starpu_save_thread_env
     struct _starpu_worker * current_worker;
     struct _starpu_worker_set * current_worker_set;
     unsigned * current_mem_node;
+#ifdef STARPU_OPENMP
+    struct starpu_omp_thread * current_omp_thread;
+    struct starpu_omp_task * current_omp_task;
+#endif
 };
 
 struct starpu_save_thread_env save_thread_env[STARPU_MAXMPIDEVS];
@@ -149,10 +153,14 @@ static void _starpu_src_common_handle_stored_async(struct _starpu_mp_node *node)
 	{
 		/* We pop a message and handle it */
 		struct mp_message * message = mp_message_list_pop_back(&node->message_queue);
+        /* Release mutex during handle */
+	    STARPU_PTHREAD_MUTEX_UNLOCK(&node->message_queue_mutex);
 		_starpu_src_common_handle_async(node, message->buffer,
 				message->size, message->type);
 		free(message->buffer);
 		mp_message_delete(message);
+        /* Take it again */
+	    STARPU_PTHREAD_MUTEX_LOCK(&node->message_queue_mutex);
 	}
 	STARPU_PTHREAD_MUTEX_UNLOCK(&node->message_queue_mutex);
 }
@@ -476,7 +484,7 @@ int _starpu_src_common_allocate(struct _starpu_mp_node *mp_node,
 
 	STARPU_ASSERT(answer == STARPU_ANSWER_ALLOCATE &&
 			arg_size == sizeof(*addr));
-
+    
 	memcpy(addr, arg, arg_size);
 
 	return 0;
@@ -499,6 +507,7 @@ int _starpu_src_common_copy_host_to_sink(const struct _starpu_mp_node *mp_node,
 	struct _starpu_mp_transfer_command cmd = {size, dst};
 
 	_starpu_mp_common_send_command(mp_node, STARPU_RECV_FROM_HOST, &cmd, sizeof(cmd));
+
 	mp_node->dt_send(mp_node, src, size);
 
 	return 0;
@@ -509,9 +518,17 @@ int _starpu_src_common_copy_host_to_sink(const struct _starpu_mp_node *mp_node,
 int _starpu_src_common_copy_sink_to_host(const struct _starpu_mp_node *mp_node,
 		void *src, void *dst, size_t size)
 {
+    enum _starpu_mp_command answer;
+	void *arg;
+	int arg_size;
 	struct _starpu_mp_transfer_command cmd = {size, src};
 
 	_starpu_mp_common_send_command(mp_node, STARPU_SEND_TO_HOST, &cmd, sizeof(cmd));
+
+    answer = _starpu_src_common_wait_command_sync(mp_node, &arg, &arg_size);
+     
+    STARPU_ASSERT(answer == STARPU_SEND_TO_HOST);
+
 	mp_node->dt_recv(mp_node, dst, size);
 
 	return 0;
@@ -660,17 +677,39 @@ int _starpu_src_common_locate_file(char *located_file_name,
 
 
 #if defined(STARPU_USE_MPI_MASTER_SLAVE) && !defined(STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD)
+
+/* Record the calling thread's current execution context in slot THIS of
+ * save_thread_env: the current task, plus the values bound to the worker,
+ * worker-set and memory-node thread-specific keys (and to the OpenMP thread
+ * and task keys when STARPU_OPENMP is defined). */
+void _starpu_src_common_init_switch_env(unsigned this)
+{
+    struct starpu_save_thread_env *slot = &save_thread_env[this];
+
+    slot->current_mem_node = STARPU_PTHREAD_GETSPECIFIC(_starpu_memory_node_key);
+    slot->current_worker_set = STARPU_PTHREAD_GETSPECIFIC(_starpu_worker_set_key);
+    slot->current_worker = STARPU_PTHREAD_GETSPECIFIC(_starpu_worker_key);
+    slot->current_task = starpu_task_get_current();
+#ifdef STARPU_OPENMP
+    slot->current_omp_task = STARPU_PTHREAD_GETSPECIFIC(omp_task_key);
+    slot->current_omp_thread = STARPU_PTHREAD_GETSPECIFIC(omp_thread_key);
+#endif
+}
+
+/* Swap the per-thread execution context: store the calling thread's current
+ * task and thread-specific key values into slot OLD of save_thread_env, then
+ * install the values previously stored in slot NEW. */
 static void _starpu_src_common_switch_env(unsigned old, unsigned new)
 {
+    /* Save phase: snapshot the current context into slot OLD. */
     save_thread_env[old].current_task = starpu_task_get_current();
     save_thread_env[old].current_worker = STARPU_PTHREAD_GETSPECIFIC(_starpu_worker_key);
     save_thread_env[old].current_worker_set = STARPU_PTHREAD_GETSPECIFIC(_starpu_worker_set_key);
     save_thread_env[old].current_mem_node = STARPU_PTHREAD_GETSPECIFIC(_starpu_memory_node_key);
+#ifdef STARPU_OPENMP
+    save_thread_env[old].current_omp_thread = STARPU_PTHREAD_GETSPECIFIC(omp_thread_key);
+    save_thread_env[old].current_omp_task = STARPU_PTHREAD_GETSPECIFIC(omp_task_key);
+#endif
+
 
+    /* Restore phase: make slot NEW the active context of this thread. */
     _starpu_set_current_task(save_thread_env[new].current_task);
     STARPU_PTHREAD_SETSPECIFIC(_starpu_worker_key, save_thread_env[new].current_worker);
     STARPU_PTHREAD_SETSPECIFIC(_starpu_worker_set_key, save_thread_env[new].current_worker_set);
     STARPU_PTHREAD_SETSPECIFIC(_starpu_memory_node_key, save_thread_env[new].current_mem_node);
+#ifdef STARPU_OPENMP
+    STARPU_PTHREAD_SETSPECIFIC(omp_thread_key, save_thread_env[new].current_omp_thread);
+    STARPU_PTHREAD_SETSPECIFIC(omp_task_key, save_thread_env[new].current_omp_task); 
+#endif
 }
 #endif
 
@@ -801,8 +840,8 @@ void _starpu_src_common_workers_set(struct _starpu_worker_set * worker_set,
 	{
         for (device = 0; device < ndevices ; device++)
         {
+            _starpu_src_common_switch_env((device-1)%ndevices, device);
             _starpu_src_common_worker_internal_work(&worker_set[device], mp_node[device], tasks+offsetmemnode[device], memnode[device]);
-            _starpu_src_common_switch_env(device, (device+1)%ndevices);
         }
     }
 	free(tasks);

+ 1 - 0
src/drivers/mp_common/source_common.h

@@ -76,6 +76,7 @@ void _starpu_src_common_worker(struct _starpu_worker_set * worker_set,
 			       struct _starpu_mp_node * node_set);
 
 #if defined(STARPU_USE_MPI_MASTER_SLAVE) && !defined(STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD)
+void _starpu_src_common_init_switch_env(unsigned this);
 void _starpu_src_common_workers_set(struct _starpu_worker_set * worker_set,
                  int ndevices,
                  struct _starpu_mp_node ** mp_node);

+ 26 - 16
src/drivers/mpi/driver_mpi_common.c

@@ -131,17 +131,25 @@ void _starpu_mpi_common_mp_initialize_src_sink(struct _starpu_mp_node *node)
 
 int _starpu_mpi_common_recv_is_ready(const struct _starpu_mp_node *mp_node)
 {
-    int res, tag;
+    int res, tag, source;
     int flag = 0;
     int id_proc;
     MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
 
     if (id_proc == src_node_id)
+    {
+        /* Source has mp_node defined */
         tag = mp_node->mp_connection.mpi_remote_nodeid;
+        source = mp_node->mp_connection.mpi_remote_nodeid;
+    }
     else
-        tag = id_proc;
-    
-    res = MPI_Iprobe(MPI_ANY_SOURCE, tag, MPI_COMM_WORLD, &flag, MPI_STATUS_IGNORE);
+    {
+        /* Sink can have sink to sink message */
+        tag = MPI_ANY_TAG;
+        source = MPI_ANY_SOURCE;
+    }
+        
+    res = MPI_Iprobe(source, tag, MPI_COMM_WORLD, &flag, MPI_STATUS_IGNORE);
     STARPU_ASSERT_MSG(res == MPI_SUCCESS, "MPI Master/Slave cannot test if we received a message !");
 
     return flag;
@@ -150,7 +158,6 @@ int _starpu_mpi_common_recv_is_ready(const struct _starpu_mp_node *mp_node)
 /* SEND to source node */
 void _starpu_mpi_common_send(const struct _starpu_mp_node *node, void *msg, int len)
 {
-    printf("envoi %d B to %d \n", len, node->mp_connection.mpi_remote_nodeid);
     int res, tag;
     int id_proc;
     MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
@@ -160,6 +167,8 @@ void _starpu_mpi_common_send(const struct _starpu_mp_node *node, void *msg, int
     else
         tag = id_proc;
 
+    printf("envoi %d B to %d et tag %d\n", len, node->mp_connection.mpi_remote_nodeid, tag);
+
     res = MPI_Send(msg, len, MPI_BYTE, node->mp_connection.mpi_remote_nodeid, tag, MPI_COMM_WORLD);
     STARPU_ASSERT_MSG(res == MPI_SUCCESS, "MPI Master/Slave cannot receive a msg with a size of %d Bytes !", len);
 }
@@ -167,9 +176,9 @@ void _starpu_mpi_common_send(const struct _starpu_mp_node *node, void *msg, int
 /* RECV to source node */
 void _starpu_mpi_common_recv(const struct _starpu_mp_node *node, void *msg, int len)
 {
-    printf("recv %d B from %d \n", len, node->mp_connection.mpi_remote_nodeid);
     int res, tag;
     int id_proc;
+    MPI_Status s;
     MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
 
     if (id_proc == src_node_id)
@@ -177,7 +186,13 @@ void _starpu_mpi_common_recv(const struct _starpu_mp_node *node, void *msg, int
     else
         tag = id_proc;
 
-    res = MPI_Recv(msg, len, MPI_BYTE, node->mp_connection.mpi_remote_nodeid, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+    printf("recv %d B from %d in %p et tag %d\n", len, node->mp_connection.mpi_remote_nodeid, msg, tag);
+
+    res = MPI_Recv(msg, len, MPI_BYTE, node->mp_connection.mpi_remote_nodeid, tag, MPI_COMM_WORLD, &s);
+    int num_expected;
+    MPI_Get_count(&s, MPI_BYTE, &num_expected);
+
+    STARPU_ASSERT_MSG(num_expected == len, "MPI Master/Slave received a msg with a size of %d Bytes (expected %d Bytes) !", num_expected, len);
     STARPU_ASSERT_MSG(res == MPI_SUCCESS, "MPI Master/Slave cannot receive a msg with a size of %d Bytes !", len);
 }
 
@@ -188,10 +203,8 @@ void _starpu_mpi_common_send_to_device(const struct _starpu_mp_node *node, int d
     int id_proc;
     MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
 
-    if (id_proc == src_node_id)
-        tag = node->mp_connection.mpi_remote_nodeid;
-    else
-        tag = id_proc;
+    tag = dst_devid;
+    printf("send %d bytes from %d from %p et tag %d\n", len, dst_devid, msg, tag);
 
     res = MPI_Send(msg, len, MPI_BYTE, dst_devid, tag, MPI_COMM_WORLD);
     STARPU_ASSERT_MSG(res == MPI_SUCCESS, "MPI Master/Slave cannot receive a msg with a size of %d Bytes !", len);
@@ -204,10 +217,8 @@ void _starpu_mpi_common_recv_from_device(const struct _starpu_mp_node *node, int
     int id_proc;
     MPI_Comm_rank(MPI_COMM_WORLD, &id_proc);
 
-    if (id_proc == src_node_id)
-        tag = node->mp_connection.mpi_remote_nodeid;
-    else
-        tag = id_proc;
+    tag = src_devid;
+    printf("nop recv %d bytes from %d et tag %d\n", len, src_devid, tag);
 
     res = MPI_Recv(msg, len, MPI_BYTE, src_devid, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
     STARPU_ASSERT_MSG(res == MPI_SUCCESS, "MPI Master/Slave cannot receive a msg with a size of %d Bytes !", len);
@@ -321,6 +332,5 @@ void _starpu_mpi_common_measure_bandwidth_latency(double * bandwidth_htod, doubl
 
         id++;
     }
-    
     free(buf);
 }

+ 1 - 0
src/drivers/mpi/driver_mpi_sink.c

@@ -20,6 +20,7 @@
 #include <dlfcn.h>
 
 #include "driver_mpi_sink.h"
+#include "driver_mpi_source.h"
 #include "driver_mpi_common.h"
 
 void _starpu_mpi_sink_init(struct _starpu_mp_node *node)

+ 51 - 1
src/drivers/mpi/driver_mpi_source.c

@@ -23,6 +23,8 @@
 #include <drivers/mpi/driver_mpi_source.h>
 #include <drivers/mpi/driver_mpi_common.h>
 
+#include <datawizard/memory_nodes.h>
+
 #include <drivers/driver_common/driver_common.h>
 #include <drivers/mp_common/source_common.h>
 
@@ -58,6 +60,52 @@ void _starpu_mpi_source_deinit(struct _starpu_mp_node *node)
 
 }
 
+/* Map a memory node number to the node descriptor stored in mpi_ms_nodes.
+ * The index is the device id that _starpu_memory_node_get_devid() reports
+ * for MEMORY_NODE; it is asserted to lie in [0, STARPU_MAXMPIDEVS). */
+struct _starpu_mp_node *_starpu_mpi_src_get_mp_node_from_memory_node(int memory_node)
+{
+    int devid;
+
+    devid = _starpu_memory_node_get_devid(memory_node);
+    STARPU_ASSERT_MSG(devid >= 0 && devid < STARPU_MAXMPIDEVS, "bogus devid %d for memory node %d\n", devid, memory_node);
+
+    struct _starpu_mp_node *mp_node = mpi_ms_nodes[devid];
+    return mp_node;
+}
+
+/* Ask the node descriptor associated with MEMORY_NODE to allocate SIZE bytes
+ * by forwarding to _starpu_src_common_allocate(), which fills *ADDR.
+ * Returns that function's result unchanged. */
+int _starpu_mpi_src_allocate_memory(void ** addr, size_t size, unsigned memory_node)
+{
+    /* Deliberately not const: _starpu_src_common_allocate() takes a
+     * non-const node pointer, so a const local would discard the qualifier
+     * at the call below. */
+    struct _starpu_mp_node *mp_node = _starpu_mpi_src_get_mp_node_from_memory_node(memory_node);
+    return _starpu_src_common_allocate(mp_node, addr, size);
+}
+
+/* Release ADDR through the node descriptor associated with MEMORY_NODE by
+ * forwarding to _starpu_src_common_free(). */
+void _starpu_mpi_source_free_memory(void *addr, unsigned memory_node)
+{
+    const struct _starpu_mp_node *owner = _starpu_mpi_src_get_mp_node_from_memory_node(memory_node);
+
+    _starpu_src_common_free(owner, addr);
+}
+
+/* Transfer SIZE bytes from the address SRC to the address DST in the
+ * DST_NODE memory node. The node descriptor is looked up from DST_NODE;
+ * SRC_NODE is unused. Returns the result of
+ * _starpu_src_common_copy_host_to_sink(). */
+int _starpu_mpi_copy_ram_to_mpi(void *src, unsigned src_node STARPU_ATTRIBUTE_UNUSED, void *dst, unsigned dst_node, size_t size)
+{
+    const struct _starpu_mp_node *sink = _starpu_mpi_src_get_mp_node_from_memory_node(dst_node);
+
+    return _starpu_src_common_copy_host_to_sink(sink, src, dst, size);
+}
+ 
+/* Transfer SIZE bytes from the address SRC in the SRC_NODE memory node to
+ * the address DST. The node descriptor is looked up from SRC_NODE; DST_NODE
+ * is unused. Returns the result of
+ * _starpu_src_common_copy_sink_to_host(). */
+int _starpu_mpi_copy_mpi_to_ram(void *src, unsigned src_node, void *dst, unsigned dst_node STARPU_ATTRIBUTE_UNUSED, size_t size)
+{
+    const struct _starpu_mp_node *sink = _starpu_mpi_src_get_mp_node_from_memory_node(src_node);
+
+    return _starpu_src_common_copy_sink_to_host(sink, src, dst, size);
+}
+
+/* Transfer SIZE bytes from SRC in the SRC_NODE memory node to DST in the
+ * DST_NODE memory node. Both node descriptors are looked up from their
+ * memory node numbers and handed to _starpu_src_common_copy_sink_to_sink(),
+ * whose result is returned. */
+int _starpu_mpi_copy_sink_to_sink(void *src, unsigned src_node, void *dst, unsigned dst_node, size_t size)
+{
+    struct _starpu_mp_node *from = _starpu_mpi_src_get_mp_node_from_memory_node(src_node);
+    struct _starpu_mp_node *to = _starpu_mpi_src_get_mp_node_from_memory_node(dst_node);
+
+    return _starpu_src_common_copy_sink_to_sink(from, to, src, dst, size);
+}
+
+
 int _starpu_mpi_ms_src_register_kernel(starpu_mpi_ms_func_symbol_t *symbol, const char *func_name)
 {
 	unsigned int func_name_size = (strlen(func_name) + 1) * sizeof(char);
@@ -149,7 +197,7 @@ void(* _starpu_mpi_ms_src_get_kernel_from_job(const struct _starpu_mp_node *node
 		/* If user dont define any starpu_mpi_ms_fun_t in cl->mpi_ms_func we try to use
 		 * cpu_func_name.
 		 */
-		char *func_name = _starpu_task_get_cpu_name_nth_implementation(j->task->cl, j->nimpl);
+		const char *func_name = _starpu_task_get_cpu_name_nth_implementation(j->task->cl, j->nimpl);
 		if (func_name)
 		{
 			starpu_mpi_ms_func_symbol_t symbol;
@@ -216,6 +264,7 @@ void *_starpu_mpi_src_worker(void *arg)
         /* unsigned memnode = baseworker->memory_node; */
 
         _starpu_driver_start(baseworker, _STARPU_FUT_MPI_KEY, 0);
+
 #ifdef STARPU_USE_FXT             
         for (i = 1; i < worker_set->nworkers; i++)
             _starpu_worker_start(&worker_set->workers[i], _STARPU_FUT_MPI_KEY, 0);
@@ -252,6 +301,7 @@ void *_starpu_mpi_src_worker(void *arg)
         }
     
 #ifndef STARPU_MPI_MASTER_SLAVE_MULTIPLE_THREAD
+        _starpu_src_common_init_switch_env(workersetnum);
     }  /* for */
 
     /* set the worker zero for the main thread */

+ 8 - 0
src/drivers/mpi/driver_mpi_source.h

@@ -26,6 +26,7 @@
 /* Array of structures containing all the informations useful to send
  * and receive informations with devices */
 extern struct _starpu_mp_node *mpi_ms_nodes[STARPU_MAXMICDEVS];
+struct _starpu_mp_node *_starpu_mpi_src_get_mp_node_from_memory_node(int memory_node);
 
 unsigned _starpu_mpi_src_get_device_count();
 void *_starpu_mpi_src_worker(void *arg);
@@ -34,6 +35,13 @@ void _starpu_mpi_exit_useless_node(int devid);
 void _starpu_mpi_source_init(struct _starpu_mp_node *node);
 void _starpu_mpi_source_deinit(struct _starpu_mp_node *node);
 
+int _starpu_mpi_src_allocate_memory(void ** addr, size_t size, unsigned memory_node);
+void _starpu_mpi_source_free_memory(void *addr, unsigned memory_node);
+
+int _starpu_mpi_copy_mpi_to_ram(void *src, unsigned src_node, void *dst, unsigned dst_node STARPU_ATTRIBUTE_UNUSED, size_t size);
+int _starpu_mpi_copy_ram_to_mpi(void *src, unsigned src_node STARPU_ATTRIBUTE_UNUSED, void *dst, unsigned dst_node, size_t size);
+int _starpu_mpi_copy_sink_to_sink(void *src, unsigned src_node, void *dst, unsigned dst_node, size_t size);
+
 void(* _starpu_mpi_ms_src_get_kernel_from_job(const struct _starpu_mp_node *node STARPU_ATTRIBUTE_UNUSED, struct _starpu_job *j))(void);
 
 #endif /* STARPU_USE_MPI_MASTER_SLAVE */