Ver código fonte

Send the return value (cl_ret) from the MPI slave to the master

HE Kun 4 anos atrás
pai
commit
2a3ca2760f

+ 3 - 0
src/drivers/mp_common/mp_common.h

@@ -139,6 +139,9 @@ struct mp_task
 	void **interfaces;
 	unsigned nb_interfaces;
 	void *cl_arg;
+	unsigned cl_arg_size;
+	void *cl_ret;
+	unsigned cl_ret_size;
 	unsigned coreid;
 	enum starpu_codelet_type type;
 	int is_parallel_task;

+ 30 - 3
src/drivers/mp_common/sink_common.c

@@ -511,7 +511,7 @@ static void _starpu_sink_common_pre_execution_message(struct _starpu_mp_node *no
 	_starpu_sink_common_append_message(node, message);
 }
 
-/* Append to the message list a "STARPU_EXECUTION_COMPLETED" message
+/* Append to the message list a "STARPU_EXECUTION_COMPLETED" message, followed by cl_ret when one is set
  */
 static void _starpu_sink_common_execution_completed_message(struct _starpu_mp_node *node, struct mp_task *task)
 {
@@ -521,10 +521,23 @@ static void _starpu_sink_common_execution_completed_message(struct _starpu_mp_no
 		message->type = STARPU_MP_COMMAND_EXECUTION_DETACHED_COMPLETED;
 	else
 		message->type = STARPU_MP_COMMAND_EXECUTION_COMPLETED;
-	_STARPU_MALLOC(message->buffer, sizeof(int));
-	*(int*) message->buffer = task->coreid;
+
 	message->size = sizeof(int);
 
+	/* If the user didn't give any cl_ret, there is no need to send it */
+	 if (task->cl_ret)
+	 {
+	 	STARPU_ASSERT(task->cl_ret_size);
+	 	message->size += task->cl_ret_size;
+	 }
+
+	_STARPU_MALLOC(message->buffer, message->size);
+
+	*(int*) message->buffer = task->coreid;
+
+	 if (task->cl_ret)
+	 	memcpy( message->buffer+sizeof(int), task->cl_ret, task->cl_ret_size);
+
 	/* Append the message to the queue */
 	_starpu_sink_common_append_message(node, message);
 }
@@ -602,8 +615,21 @@ static void _starpu_sink_common_execute_kernel(struct _starpu_mp_node *node, int
 	{
 		if (_starpu_get_disable_kernels() <= 0)
 		{
+			struct starpu_task s_task;
+			starpu_task_init(&s_task);
+
+			/*copy cl_arg and cl_arg_size from mp_task into starpu_task*/
+			(&s_task)->cl_arg=task->cl_arg;
+			(&s_task)->cl_arg_size=task->cl_arg_size;
+
+			_starpu_set_current_task(&s_task);
 			/* execute the task */
 			task->kernel(task->interfaces,task->cl_arg);
+			_starpu_set_current_task(NULL);
+
+			/*copy cl_ret and cl_ret_size from starpu_task into mp_task*/
+			task->cl_ret=(&s_task)->cl_ret;
+			task->cl_ret_size=(&s_task)->cl_ret_size;
 		}
 	}
 
@@ -756,6 +782,7 @@ void _starpu_sink_common_execute(struct _starpu_mp_node *node, void *arg, int ar
 		unsigned cl_arg_size = arg_size - (arg_ptr - (uintptr_t) arg);
 		_STARPU_MALLOC(task->cl_arg, cl_arg_size);
 		memcpy(task->cl_arg, (void *) arg_ptr, cl_arg_size);
+		task->cl_arg_size=cl_arg_size;
 	}
 	else
 		task->cl_arg = NULL;

+ 18 - 2
src/drivers/mp_common/source_common.c

@@ -88,15 +88,31 @@ static int _starpu_src_common_process_completed_job(struct _starpu_mp_node *node
 {
 	int coreid;
 
-	STARPU_ASSERT(sizeof(coreid) == arg_size);
+	uintptr_t arg_ptr = (uintptr_t) arg;
 
-	coreid = *(int *) arg;
+	coreid = *(int *) arg_ptr;
+	arg_ptr += sizeof(coreid);
 
 	struct _starpu_worker *worker = &workerset->workers[coreid];
 	struct _starpu_job *j = _starpu_get_job_associated_to_task(worker->current_task);
 
+	struct starpu_task *task = j->task;
+	STARPU_ASSERT(task);
+
 	struct _starpu_worker * old_worker = _starpu_get_local_worker_key();
 
+	/* Was cl_ret sent ? */
+	if (arg_size > arg_ptr - (uintptr_t) arg)
+	{
+		/* Copy cl_ret into the task */
+		unsigned cl_ret_size = arg_size - (arg_ptr - (uintptr_t) arg);
+		printf("cl_ret_size in master is %d\n", cl_ret_size);
+		_STARPU_MALLOC(task->cl_ret, cl_ret_size);
+		memcpy(task->cl_ret, (void *) arg_ptr, cl_ret_size);
+	}
+	else
+		task->cl_ret = NULL;
+
         /* if arg is not copied we release the mutex */
         if (!stored)
                 STARPU_PTHREAD_MUTEX_UNLOCK(&node->connection_mutex);