|
@@ -322,6 +322,33 @@ int starpu_mpi_recv(starpu_data_handle_t data_handle, int source, int mpi_tag, M
|
|
|
return 0;
|
|
|
}
|
|
|
|
|
|
+static void _starpu_mpi_probe_func(struct _starpu_mpi_req *req)
|
|
|
+{
|
|
|
+ _STARPU_MPI_LOG_IN();
|
|
|
+
|
|
|
+ _starpu_mpi_handle_allocate_datatype(req->data_handle, &req->datatype, &req->user_datatype);
|
|
|
+#ifdef STARPU_DEVEL
|
|
|
+#warning TODO: release that assert
|
|
|
+#endif
|
|
|
+ assert(req->user_datatype == 0);
|
|
|
+ req->count = 1;
|
|
|
+ req->ptr = starpu_handle_get_local_ptr(req->data_handle);
|
|
|
+
|
|
|
+ _STARPU_MPI_DEBUG("MPI probe tag %d dst %d ptr %p datatype %p count %d req %p\n", req->mpi_tag, req->srcdst, req->ptr, req->datatype, (int)req->count, &req->request);
|
|
|
+
|
|
|
+ _starpu_mpi_handle_detached_request(req);
|
|
|
+
|
|
|
+ _STARPU_MPI_LOG_OUT();
|
|
|
+}
|
|
|
+
|
|
|
+int starpu_mpi_irecv_probe_detached(starpu_data_handle_t data_handle, int source, int mpi_tag, MPI_Comm comm, void (*callback)(void *), void *arg)
|
|
|
+{
|
|
|
+ _STARPU_MPI_LOG_IN();
|
|
|
+ _starpu_mpi_isend_irecv_common(data_handle, source, mpi_tag, comm, 1, callback, arg, PROBE_REQ, _starpu_mpi_probe_func, STARPU_W);
|
|
|
+ _STARPU_MPI_LOG_OUT();
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
/********************************************************/
|
|
|
/* */
|
|
|
/* Wait functionalities */
|
|
@@ -563,6 +590,7 @@ static char *_starpu_mpi_request_type(enum _starpu_mpi_request_type request_type
|
|
|
case WAIT_REQ: return "WAIT_REQ";
|
|
|
case TEST_REQ: return "TEST_REQ";
|
|
|
case BARRIER_REQ: return "BARRIER_REQ";
|
|
|
+ case PROBE_REQ: return "PROBE_REQ";
|
|
|
default: return "unknown request type";
|
|
|
}
|
|
|
}
|
|
@@ -573,7 +601,17 @@ static void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req)
|
|
|
_STARPU_MPI_LOG_IN();
|
|
|
|
|
|
_STARPU_MPI_DEBUG("complete MPI (%s %d) data %p req %p - tag %d\n", _starpu_mpi_request_type(req->request_type), req->srcdst, req->data_handle, &req->request, req->mpi_tag);
|
|
|
- if (req->request_type == RECV_REQ || req->request_type == SEND_REQ)
|
|
|
+ if (req->request_type == PROBE_REQ)
|
|
|
+ {
|
|
|
+#ifdef STARPU_DEVEL
|
|
|
+#warning TODO: instead of calling MPI_Recv, we should post a starpu mpi recv request
|
|
|
+#endif
|
|
|
+ MPI_Status status;
|
|
|
+ memset(&status, 0, sizeof(MPI_Status));
|
|
|
+ req->ret = MPI_Recv(req->ptr, req->count, req->datatype, req->srcdst, req->mpi_tag, req->comm, &status);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (req->request_type == RECV_REQ || req->request_type == SEND_REQ || req->request_type == PROBE_REQ)
|
|
|
{
|
|
|
if (req->user_datatype == 1)
|
|
|
{
|
|
@@ -590,7 +628,7 @@ static void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req)
|
|
|
starpu_data_release(req->data_handle);
|
|
|
}
|
|
|
|
|
|
- if (req->request_type == RECV_REQ)
|
|
|
+ if (req->request_type == RECV_REQ || req->request_type == PROBE_REQ)
|
|
|
{
|
|
|
TRACE_MPI_IRECV_END(req->srcdst, req->mpi_tag);
|
|
|
}
|
|
@@ -599,7 +637,7 @@ static void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req)
|
|
|
if (req->callback)
|
|
|
req->callback(req->callback_arg);
|
|
|
|
|
|
- /* tell anyone potentiallly waiting on the request that it is
|
|
|
+ /* tell anyone potentially waiting on the request that it is
|
|
|
* terminated now */
|
|
|
_STARPU_PTHREAD_MUTEX_LOCK(&req->req_mutex);
|
|
|
req->completed = 1;
|
|
@@ -659,7 +697,15 @@ static void _starpu_mpi_test_detached_requests(void)
|
|
|
_STARPU_PTHREAD_MUTEX_UNLOCK(&detached_requests_mutex);
|
|
|
|
|
|
//_STARPU_MPI_DEBUG("Test detached request %p - mpitag %d - TYPE %s %d\n", &req->request, req->mpi_tag, _starpu_mpi_request_type(req->request_type), req->srcdst);
|
|
|
- req->ret = MPI_Test(&req->request, &flag, &status);
|
|
|
+ if (req->request_type == PROBE_REQ)
|
|
|
+ {
|
|
|
+ req->ret = MPI_Iprobe(req->srcdst, req->mpi_tag, req->comm, &flag, &status);
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ req->ret = MPI_Test(&req->request, &flag, &status);
|
|
|
+ }
|
|
|
+
|
|
|
STARPU_ASSERT(req->ret == MPI_SUCCESS);
|
|
|
|
|
|
if (flag)
|