
Implement "detached" MPI requests which need not be waited or tested
explicitely. For now, we use a very simple active polling mechanism.
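
The commit also adds tests/mpi_isend_detached.c, but its body is not part of this excerpt. The following is a minimal, hypothetical sketch of how the new starpu_mpi_isend_detached() entry point might be used, modeled on the mpi_irecv_detached.c test shown at the end of this commit; the program layout, the tag value and the callback name are assumptions rather than code from the commit.

#include <stdlib.h>
#include <pthread.h>
#include <starpu_mpi.h>

#define SIZE	16

static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t cond = PTHREAD_COND_INITIALIZER;

/* Run by the StarPU-MPI progression thread once MPI reports the detached
 * send as complete; no explicit starpu_mpi_wait()/starpu_mpi_test() needed */
static void send_done(void *arg)
{
	unsigned *sent = arg;

	pthread_mutex_lock(&mutex);
	*sent = 1;
	pthread_cond_signal(&cond);
	pthread_mutex_unlock(&mutex);
}

int main(int argc, char **argv)
{
	int rank;

	MPI_Init(NULL, NULL);
	MPI_Comm_rank(MPI_COMM_WORLD, &rank);

	starpu_init(NULL);
	starpu_mpi_initialize();

	float *tab = calloc(SIZE, sizeof(float));
	starpu_data_handle tab_handle;
	starpu_register_vector_data(&tab_handle, 0, (uintptr_t)tab, SIZE, sizeof(float));

	if (rank == 0) {
		struct starpu_mpi_req_s req;
		unsigned sent = 0;

		/* Post the send and return immediately; the callback is the
		 * only completion notification for a detached request */
		starpu_mpi_isend_detached(tab_handle, &req, 1, 0x42,
					MPI_COMM_WORLD, send_done, &sent);

		pthread_mutex_lock(&mutex);
		while (!sent)
			pthread_cond_wait(&cond, &mutex);
		pthread_mutex_unlock(&mutex);
	}
	else if (rank == 1) {
		MPI_Status status;
		starpu_mpi_recv(tab_handle, 0, 0x42, MPI_COMM_WORLD, &status);
	}

	starpu_mpi_shutdown();
	starpu_shutdown();
	MPI_Finalize();

	return 0;
}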

Cédric Augonnet 15 years ago
parent
commit
0e22fdfbda
4 changed files with 240 additions and 11 deletions
  1. mpi/Makefile.am (+15 -1)
  2. mpi/starpu_mpi.c (+127 -9)
  3. mpi/starpu_mpi.h (+3 -1)
  4. mpi/tests/mpi_irecv_detached.c (+95 -0)

+ 15 - 1
mpi/Makefile.am

@@ -14,7 +14,7 @@
 # See the GNU Lesser General Public License in COPYING.LGPL for more details.
 #
 
-CC=mpicc
+CC=/opt/mpich2-mx/bin/mpicc
 
 if USE_CUDA
 # TODO define NVCCFLAGS
@@ -50,6 +50,8 @@ mpiexamplebin_PROGRAMS =				\
 	tests/pingpong					\
 	tests/mpi_test					\
 	tests/mpi_isend					\
+	tests/mpi_isend_detached			\
+	tests/mpi_irecv_detached			\
 	tests/mpi_irecv					\
 	tests/ring					\
 	tests/ring_async				\
@@ -62,6 +64,18 @@ tests_mpi_isend_LDADD =					\
 tests_mpi_isend_SOURCES =				\
 	tests/mpi_isend.c
 
+tests_mpi_isend_detached_LDADD =			\
+	libstarpumpi.la
+
+tests_mpi_isend_detached_SOURCES =			\
+	tests/mpi_isend_detached.c
+
+tests_mpi_irecv_detached_LDADD =			\
+	libstarpumpi.la
+
+tests_mpi_irecv_detached_SOURCES =			\
+	tests/mpi_irecv_detached.c
+
 tests_mpi_irecv_LDADD =					\
 	libstarpumpi.la
 

+ 127 - 9
mpi/starpu_mpi.c

@@ -23,6 +23,10 @@ static void handle_request_termination(struct starpu_mpi_req_s *req);
 /* The list of requests that have been newly submitted by the application */
 static starpu_mpi_req_list_t new_requests; 
 
+/* The list of detached requests that have already been submitted to MPI */
+static starpu_mpi_req_list_t detached_requests;
+static pthread_mutex_t detached_requests_mutex;
+
 static pthread_cond_t cond;
 static pthread_mutex_t mutex;
 static pthread_t progress_thread;
@@ -78,9 +82,35 @@ int starpu_mpi_isend(starpu_data_handle data_handle, struct starpu_mpi_req_s *re
  *	Isend (detached)
  */
 
-int starpu_mpi_isend_detached(starpu_data_handle data_handle, struct starpu_mpi_req_s *req, int dest, int mpi_tag, MPI_Comm comm, void (*callback)(void *), void *arg)
+int starpu_mpi_isend_detached(starpu_data_handle data_handle, struct starpu_mpi_req_s *req,
+				int dest, int mpi_tag, MPI_Comm comm, void (*callback)(void *), void *arg)
 {
-	/* TODO */
+	STARPU_ASSERT(req);
+
+	memset(req, 0, sizeof(struct starpu_mpi_req_s));
+
+	/* Initialize the request structure */
+	req->submitted = 0;
+	req->completed = 0;
+	pthread_mutex_init(&req->req_mutex, NULL);
+	pthread_cond_init(&req->req_cond, NULL);
+
+	req->data_handle = data_handle;
+	req->srcdst = dest;
+	req->mpi_tag = mpi_tag;
+	req->comm = comm;
+	req->func = starpu_mpi_isend_func;
+
+	req->detached = 1;
+	req->callback = callback;
+	req->callback_arg = arg;
+
+	/* Asynchronously request StarPU to fetch the data in main memory: when
+	 * it is available in main memory, submit_mpi_req(req) is called and
+	 * the request is actually submitted  */
+	starpu_sync_data_with_mem_non_blocking(data_handle, STARPU_R,
+			submit_mpi_req, (void *)req);
+
 	return 0;
 }
 
@@ -131,6 +161,42 @@ int starpu_mpi_irecv(starpu_data_handle data_handle, struct starpu_mpi_req_s *re
 }
 
 /*
+ *	Irecv (detached)
+ */
+
+int starpu_mpi_irecv_detached(starpu_data_handle data_handle, struct starpu_mpi_req_s *req, int source, int mpi_tag, MPI_Comm comm, void (*callback)(void *), void *arg)
+{
+	STARPU_ASSERT(req);
+
+	memset(req, 0, sizeof(struct starpu_mpi_req_s));
+
+	/* Initialize the request structure */
+	req->submitted = 0;
+	pthread_mutex_init(&req->req_mutex, NULL);
+	pthread_cond_init(&req->req_cond, NULL);
+
+	req->data_handle = data_handle;
+	req->srcdst = source;
+	req->mpi_tag = mpi_tag;
+	req->comm = comm;
+
+	req->detached = 1;
+	req->callback = callback;
+	req->callback_arg = arg;
+
+	req->func = starpu_mpi_irecv_func;
+
+	/* Asynchronously request StarPU to fetch the data in main memory: when
+	 * it is available in main memory, submit_mpi_req(req) is called and
+	 * the request is actually submitted  */
+	starpu_sync_data_with_mem_non_blocking(data_handle, STARPU_W,
+			submit_mpi_req, (void *)req);
+
+	return 0;
+}
+
+
+/*
  *	Recv
  */
 
@@ -285,13 +351,16 @@ void handle_request_termination(struct starpu_mpi_req_s *req)
 	MPI_Type_free(&req->datatype);
 	starpu_release_data_from_mem(req->data_handle);
 
+	/* Execute the specified callback, if any */
+	if (req->callback)
+		req->callback(req->callback_arg);
+
 	/* tell anyone potentiallly waiting on the request that it is
 	 * terminated now */
 	pthread_mutex_lock(&req->req_mutex);
 	req->completed = 1;
 	pthread_cond_broadcast(&req->req_cond);
 	pthread_mutex_unlock(&req->req_mutex);
-	
 }
 
 void submit_mpi_req(void *arg)
@@ -308,12 +377,51 @@ void submit_mpi_req(void *arg)
  *	Progression loop
  */
 
+void test_detached_requests(void)
+{
+	int flag;
+	MPI_Status status;
+	struct starpu_mpi_req_s *req, *next_req;
+
+	pthread_mutex_lock(&detached_requests_mutex);
+
+	for (req = starpu_mpi_req_list_begin(detached_requests);
+		req != starpu_mpi_req_list_end(detached_requests);
+		req = next_req)
+	{
+		next_req = starpu_mpi_req_list_next(req);
+
+		pthread_mutex_unlock(&detached_requests_mutex);
+
+		MPI_Test(&req->request, &flag, &status);
+		if (flag)
+			handle_request_termination(req);
+
+		pthread_mutex_lock(&detached_requests_mutex);
+
+		if (flag)
+			starpu_mpi_req_list_erase(detached_requests, req);
+	}
+	
+	pthread_mutex_unlock(&detached_requests_mutex);
+}
+
 void handle_new_request(struct starpu_mpi_req_s *req)
 {
 	STARPU_ASSERT(req);
 
 	/* submit the request to MPI */
 	req->func(req);
+
+	if (req->detached)
+	{
+		/* put the submitted request into the list of pending requests
+		 * so that it can be handled by the progression mechanisms */
+		pthread_mutex_lock(&mutex);
+		starpu_mpi_req_list_push_front(detached_requests, req);
+		pthread_cond_signal(&cond);
+		pthread_mutex_unlock(&mutex);
+	}
 }
 
 void *progress_thread_func(void *arg __attribute__((unused)))
@@ -326,11 +434,17 @@ void *progress_thread_func(void *arg __attribute__((unused)))
 
 	pthread_mutex_lock(&mutex);
 	while (running) {
-		/* TODO test if there is some "detached request" and progress if this is the case */
-		pthread_cond_wait(&cond, &mutex);
+		if (starpu_mpi_req_list_empty(new_requests) && starpu_mpi_req_list_empty(detached_requests))
+			pthread_cond_wait(&cond, &mutex);
+
 		if (!running)
 			break;		
 
+		/* test whether there are some terminated "detached request" */
+		pthread_mutex_unlock(&mutex);
+		test_detached_requests();
+		pthread_mutex_lock(&mutex);
+
 		/* get one request */
 		struct starpu_mpi_req_s *req;
 		while (!starpu_mpi_req_list_empty(new_requests))
@@ -360,13 +474,15 @@ int starpu_mpi_initialize(void)
 {
 	pthread_mutex_init(&mutex, NULL);
 	pthread_cond_init(&cond, NULL);
-
 	new_requests = starpu_mpi_req_list_new();
 
+	pthread_mutex_init(&detached_requests_mutex, NULL);
+	detached_requests = starpu_mpi_req_list_new();
+
 	int ret = pthread_create(&progress_thread, NULL, progress_thread_func, NULL);
 
 	pthread_mutex_lock(&mutex);
-	if (!running)
+	while (!running)
 		pthread_cond_wait(&cond, &mutex);
 	pthread_mutex_unlock(&mutex);
 
@@ -375,16 +491,18 @@ int starpu_mpi_initialize(void)
 
 int starpu_mpi_shutdown(void)
 {
+	void *value;
+
 	/* kill the progression thread */
 	pthread_mutex_lock(&mutex);
 	running = 0;
-	pthread_cond_signal(&cond);
+	pthread_cond_broadcast(&cond);
 	pthread_mutex_unlock(&mutex);
 
-	void *value;
 	pthread_join(progress_thread, &value);
 
 	/* liberate the request queues */
+	starpu_mpi_req_list_delete(detached_requests);
 	starpu_mpi_req_list_delete(new_requests);
 
 	return 0;
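
Taken together, the starpu_mpi.c hunks above make the progression thread actively poll detached requests instead of sleeping unconditionally. The loop below is only a condensed reading aid, not code from the commit: locking around the lists is omitted, and starpu_mpi_req_list_pop_back() is an assumed accessor (the call that pops new requests sits in a context line outside this excerpt).

	while (running) {
		/* sleep only when there is nothing to submit and nothing to poll */
		if (starpu_mpi_req_list_empty(new_requests)
				&& starpu_mpi_req_list_empty(detached_requests))
			pthread_cond_wait(&cond, &mutex);

		/* actively poll each detached request with MPI_Test; completed
		 * ones run their callback and leave the detached_requests list */
		test_detached_requests();

		/* hand newly posted requests to MPI; handle_new_request() pushes
		 * detached ones onto detached_requests for later polling */
		while (!starpu_mpi_req_list_empty(new_requests)) {
			struct starpu_mpi_req_s *req =
				starpu_mpi_req_list_pop_back(new_requests);
			handle_new_request(req);
		}
	}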

+ 3 - 1
mpi/starpu_mpi.h

@@ -54,7 +54,7 @@ LIST_TYPE(starpu_mpi_req,
 
 	/* in the case of detached requests */
 	unsigned detached;
-	void *arg;
+	void *callback_arg;
 	void (*callback)(void *);
 );
 
@@ -66,6 +66,8 @@ int starpu_mpi_send(starpu_data_handle data_handle,
 		int dest, int mpi_tag, MPI_Comm comm);
 int starpu_mpi_recv(starpu_data_handle data_handle,
 		int source, int mpi_tag, MPI_Comm comm, MPI_Status *status);
+int starpu_mpi_isend_detached(starpu_data_handle data_handle, struct starpu_mpi_req_s *req, int dest, int mpi_tag, MPI_Comm comm, void (*callback)(void *), void *arg);
+int starpu_mpi_irecv_detached(starpu_data_handle data_handle, struct starpu_mpi_req_s *req, int source, int mpi_tag, MPI_Comm comm, void (*callback)(void *), void *arg);
 int starpu_mpi_wait(struct starpu_mpi_req_s *req, MPI_Status *status);
 int starpu_mpi_test(struct starpu_mpi_req_s *req, int *flag, MPI_Status *status);
 int starpu_mpi_initialize(void);

+ 95 - 0
mpi/tests/mpi_irecv_detached.c

@@ -0,0 +1,95 @@
+/*
+ * StarPU
+ * Copyright (C) INRIA 2008-2009 (see AUTHORS file)
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation; either version 2.1 of the License, or (at
+ * your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ *
+ * See the GNU Lesser General Public License in COPYING.LGPL for more details.
+ */
+
+#include <starpu_mpi.h>
+
+#define NITER	2048
+#define SIZE	16
+
+float *tab;
+starpu_data_handle tab_handle;
+
+static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
+static pthread_cond_t cond = PTHREAD_COND_INITIALIZER;
+
+void callback(void *arg __attribute__((unused)))
+{
+	unsigned *received = arg;
+	
+	pthread_mutex_lock(&mutex);
+	*received = 1;
+	pthread_cond_signal(&cond);
+	pthread_mutex_unlock(&mutex);
+}
+
+
+int main(int argc, char **argv)
+{
+	MPI_Init(NULL, NULL);
+
+	int rank, size;
+
+	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+	MPI_Comm_size(MPI_COMM_WORLD, &size);
+
+	if (size != 2)
+	{
+		if (rank == 0)
+			fprintf(stderr, "We need exactly 2 processes.\n");
+
+		MPI_Finalize();
+		return 0;
+	}
+
+	starpu_init(NULL);
+	starpu_mpi_initialize();
+
+	tab = malloc(SIZE*sizeof(float));
+
+	starpu_register_vector_data(&tab_handle, 0, (uintptr_t)tab, SIZE, sizeof(float));
+
+	unsigned nloops = NITER;
+	unsigned loop;
+
+	int other_rank = (rank + 1)%2;
+
+	for (loop = 0; loop < nloops; loop++)
+	{
+		if (rank == 0)
+		{
+			starpu_mpi_send(tab_handle, other_rank, loop, MPI_COMM_WORLD);
+		}
+		else {
+			MPI_Status status;
+			struct starpu_mpi_req_s req;
+
+			int received = 0;
+			starpu_mpi_irecv_detached(tab_handle, &req, other_rank, loop, MPI_COMM_WORLD, callback, &received);
+
+			pthread_mutex_lock(&mutex);
+			while (!received)
+				pthread_cond_wait(&cond, &mutex);
+			pthread_mutex_unlock(&mutex);
+		}
+	}
+	
+	starpu_mpi_shutdown();
+	starpu_shutdown();
+
+	MPI_Finalize();
+
+	return 0;
+}
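
Assuming the Makefile.am rules above build the new binaries, the receive test could presumably be launched with something like "mpiexec -n 2 tests/mpi_irecv_detached" (launcher name and binary path are assumptions, not from the commit); the test exits early unless exactly 2 MPI processes are started.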