/* StarPU --- Runtime system for heterogeneous multicore architectures.
 *
 * Copyright (C) 2014-2020 Université de Bordeaux, CNRS (LaBRI UMR 5800), Inria
 *
 * StarPU is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 2.1 of the License, or (at
 * your option) any later version.
 *
 * StarPU is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 *
 * See the GNU Lesser General Public License in COPYING.LGPL for more details.
 */

#include <stdio.h>
#include <stdlib.h>
#include <starpu.h>
#include <starpu_mpi.h>
#include <starpu_mpi_private.h>
#include <starpu_mpi_checkpoint.h>
#include <mpi/starpu_mpi_mpi_backend.h> // Should be deduced at preprocessing (Nmad vs MPI)
#include "starpu_mpi_cache.h"

#define MAX_CP_TEMPLATE_NUMBER 32 // Arbitrary limit

starpu_pthread_mutex_t cp_template_mutex;
starpu_mpi_checkpoint_template_t cp_template_array[MAX_CP_TEMPLATE_NUMBER];
int my_rank;
int cp_template_number = 0;

static struct _starpu_mpi_req_list detached_ft_service_requests;
static unsigned detached_send_n_ft_service_requests;
static starpu_pthread_mutex_t detached_ft_service_requests_mutex;

void _starpu_mpi_post_cp_ack_recv_cb(void* _args);
void _starpu_mpi_post_cp_ack_send_cb(void* _args);
void _starpu_mpi_treat_cache_ack_no_lock_cb(void* args);

extern struct _starpu_mpi_req *_starpu_mpi_irecv_common(starpu_data_handle_t data_handle, int source, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count);
extern struct _starpu_mpi_req *_starpu_mpi_isend_common(starpu_data_handle_t data_handle, int dest, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *arg, int sequential_consistency);

static int _starpu_mpi_checkpoint_template_register(starpu_mpi_checkpoint_template_t* cp_template, int cp_id, va_list varg_list)
{
	int arg_type;
	//void* useless;
	void* ptr;
	int count;
	int backup_rank;
	int backup_of;
	// int (*_backup_of)(int);
	// int (*_backuped_by)(int);

	starpu_mpi_checkpoint_template_t _cp_template = _starpu_mpi_checkpoint_template_new(cp_id);

	va_list varg_list_copy;
	va_copy(varg_list_copy, varg_list);

	while ((arg_type = va_arg(varg_list_copy, int)) != 0)
	{
		STARPU_ASSERT_MSG(!(arg_type & STARPU_COMMUTE), "Unable to checkpoint a non-sequential task flow.\n");
		switch (arg_type)
		{
			case STARPU_R:
				ptr = va_arg(varg_list_copy, void*);
				count = 1;
				backup_rank = va_arg(varg_list_copy, int);
				backup_of = -1;
				break;
			case STARPU_VALUE:
				ptr = va_arg(varg_list_copy, void*);
				count = va_arg(varg_list_copy, int);
				backup_rank = va_arg(varg_list_copy, int);
				backup_of = va_arg(varg_list_copy, int);
				break;
			// case STARPU_DATA_ARRAY:
			// 	ptr = va_arg(varg_list_copy, void*);
			// 	count = va_arg(varg_list_copy, int);
			// 	backup_rank = va_arg(varg_list_copy, int);
			// 	backup_of = -1;
			// 	break;
			default:
				STARPU_ABORT_MSG("Unrecognized argument %d, did you perhaps forget to end arguments with 0?\n", arg_type);
				break;
		}
		_starpu_mpi_checkpoint_template_add_data(_cp_template, arg_type, ptr, count, backup_rank, backup_of);
	}
	va_end(varg_list_copy);

	_starpu_mpi_checkpoint_template_freeze(_cp_template);

	starpu_pthread_mutex_lock(&cp_template_mutex);
	for (int i=0 ; i<cp_template_number ; i++)
	{
		STARPU_ASSERT_MSG(cp_template_array[i]->cp_template_id != _cp_template->cp_template_id, "A checkpoint with id %d has already been registered.\n", _cp_template->cp_template_id);
	}
	cp_template_array[cp_template_number] = _cp_template;
	cp_template_number++;
	starpu_pthread_mutex_unlock(&cp_template_mutex);

	*cp_template = _cp_template;
	return 0;
}
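/* Usage sketch (illustrative only): assuming a public varargs wrapper
 * starpu_mpi_checkpoint_template_register() forwards its argument list to the
 * static helper above, a template combining a registered handle and a raw
 * value could be declared as follows; the application variables are
 * hypothetical.
 *
 *	starpu_mpi_checkpoint_template_t cp;
 *	starpu_data_handle_t handle; // previously registered with starpu_mpi_data_register()
 *	int token;
 *	starpu_mpi_checkpoint_template_register(&cp, 42,
 *		STARPU_R, &handle, 1,                            // ptr, backup_rank
 *		STARPU_VALUE, &token, (int)sizeof(token), 1, 0,  // ptr, count, backup_rank, backup_of
 *		0);                                              // argument list must end with 0
 */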
struct _starpu_mpi_req* _starpu_mpi_irecv_cache_aware(starpu_data_handle_t data_handle, int source, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, void (*callback)(void *), void *arg, void (*alt_callback)(void *), void *alt_arg, int sequential_consistency, int is_internal_req, starpu_ssize_t count)
{
	struct _starpu_mpi_req* req = NULL;
	int already_received = _starpu_mpi_cache_received_data_set(data_handle);
	if (already_received == 0)
	{
		if (data_tag == -1)
			_STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register\n");
		_STARPU_MPI_DEBUG(1, "Receiving data %p from %d\n", data_handle, source);
		req = _starpu_mpi_irecv_common(data_handle, source, data_tag, comm, detached, sync, callback, (void*)arg, sequential_consistency, is_internal_req, count);
	}
	else
	{
		fprintf(stderr, "STARPU CACHE: Data already received\n");
		if (alt_callback) /* may be NULL, as on the checkpoint recv path */
			alt_callback(alt_arg);
	}
	return req;
}

struct _starpu_mpi_req* _starpu_mpi_isend_cache_aware(starpu_data_handle_t data_handle, int dest, starpu_mpi_tag_t data_tag, MPI_Comm comm, unsigned detached, unsigned sync, int prio, void (*callback)(void *), void *arg, void (*alt_callback)(void *), void *alt_arg, int sequential_consistency)
{
	struct _starpu_mpi_req* req = NULL;
	int already_sent = _starpu_mpi_cache_sent_data_set(data_handle, dest);
	if (already_sent == 0)
	{
		if (data_tag == -1)
			_STARPU_ERROR("StarPU needs to be told the MPI tag of this data, using starpu_mpi_data_register\n");
		_STARPU_MPI_DEBUG(1, "Sending data %p to %d\n", data_handle, dest);
		req = _starpu_mpi_isend_common(data_handle, dest, data_tag, comm, detached, sync, prio, callback, (void*)arg, sequential_consistency);
	}
	else
	{
		fprintf(stderr, "STARPU CACHE: Data already sent\n");
		if (alt_callback)
			alt_callback(alt_arg);
	}
	return req;
}

int _starpu_mpi_checkpoint_template_submit(starpu_mpi_checkpoint_template_t cp_template)
{
	starpu_data_handle_t* handle;
	struct _starpu_mpi_checkpoint_template_item* item;
	//MPI_Comm comm;

	starpu_pthread_mutex_lock(&cp_template->mutex);
	STARPU_ASSERT_MSG(cp_template->pending==0, "Cannot submit a checkpoint while the previous instance has not succeeded.\n");
	cp_template->pending = 1;
	cp_template->remaining_ack_awaited = cp_template->message_number;

	item = _starpu_mpi_checkpoint_template_get_first_data(cp_template);
	fprintf(stderr, "begin iter\n");
	while (item != _starpu_mpi_checkpoint_template_end(cp_template))
	{
		switch (item->type)
		{
			case STARPU_VALUE:
				// starpu_data_handle_t send_handle;
				// starpu_variable_data_register(&send_handle, STARPU_MAIN_RAM, (uintptr_t)item->ptr, item->count);
				// starpu_mpi_data_register(send_handle, )
				// starpu_mpi_send
				break;
			case STARPU_R:
				handle = (starpu_data_handle_t*)item->ptr;
				if (starpu_mpi_data_get_rank(*handle)==my_rank)
				{
					/* I own this data: send it to the backup rank, and post the
					 * matching ack recv once the send completes. */
					fprintf(stderr, "sending to %d (tag %d)\n", item->backup_rank, (int)starpu_mpi_data_get_tag(*handle));
					struct _starpu_mpi_cp_ack_arg_cb* arg = calloc(1, sizeof(struct _starpu_mpi_cp_ack_arg_cb));
					arg->rank = item->backup_rank;
					arg->msg.checkpoint_id = cp_template->cp_template_id;
					arg->msg.checkpoint_instance = cp_template->cp_template_current_instance;
					_starpu_mpi_isend_cache_aware(*handle, item->backup_rank, starpu_mpi_data_get_tag(*handle), MPI_COMM_WORLD, 1, 0, 0,
					                              &_starpu_mpi_post_cp_ack_recv_cb, (void*)arg,
					                              &_starpu_mpi_treat_cache_ack_no_lock_cb, (void*)cp_template, 1);
				}
				else if (item->backup_rank==my_rank)
				{
					/* I am the backup: receive the data, and send the ack back
					 * once the recv completes. */
					fprintf(stderr, "recving from %d (tag %d)\n", starpu_mpi_data_get_rank(*handle), (int)starpu_mpi_data_get_tag(*handle));
					struct _starpu_mpi_cp_ack_arg_cb* arg = calloc(1, sizeof(struct _starpu_mpi_cp_ack_arg_cb));
					arg->rank = starpu_mpi_data_get_rank(*handle);
					arg->msg.checkpoint_id = cp_template->cp_template_id;
					arg->msg.checkpoint_instance = cp_template->cp_template_current_instance;
					_starpu_mpi_irecv_cache_aware(*handle, starpu_mpi_data_get_rank(*handle), starpu_mpi_data_get_tag(*handle), MPI_COMM_WORLD, 1, 0,
					                              &_starpu_mpi_post_cp_ack_send_cb, (void*)arg,
					                              NULL, NULL, 1, 1, 1);
				}
				break;
		}
		item = _starpu_mpi_checkpoint_template_get_next_data(cp_template, item);
	}
	starpu_pthread_mutex_unlock(&cp_template->mutex);
	return 0;
}
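/* Data/ack handshake performed by the submission above (sketch):
 *
 *	owner rank                               backup rank
 *	----------                               -----------
 *	isend(data) ---------------------------> irecv(data)
 *	on send completion:                      on recv completion:
 *	  _starpu_mpi_post_cp_ack_recv_cb          _starpu_mpi_post_cp_ack_send_cb
 *	  posts irecv(ack)  <--------------------- posts isend(ack)
 *	on ack receipt:
 *	  remaining_ack_awaited--; the instance
 *	  is marked done when it reaches 0.
 *
 * When the send side hits the communication cache (data already sent), the
 * network round-trip is skipped and _starpu_mpi_treat_cache_ack_no_lock_cb
 * decrements remaining_ack_awaited directly, under the template mutex already
 * held by the submit loop. */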
///**
// * receives param of type starpu_mpi_checkpoint_template_t
// * @param args
// * @return
// */
//void _starpu_mpi_checkpoint_ack_send_cb(void* args)
//{
//	starpu_mpi_checkpoint_template_t cp_template = (starpu_mpi_checkpoint_template_t) args;
//	starpu_pthread_mutex_lock(&cp_template->mutex);
//	cp_template->remaining_ack_awaited--;
//	starpu_pthread_mutex_unlock(&cp_template->mutex);
//}

// For test purposes
int _starpu_mpi_checkpoint_template_print(starpu_mpi_checkpoint_template_t cp_template)
{
	int val;
	int i = 0;
	struct _starpu_mpi_checkpoint_template_item* item = _starpu_mpi_checkpoint_template_get_first_data(cp_template);
	while (item != _starpu_mpi_checkpoint_template_end(cp_template))
	{
		fprintf(stderr, "Item %2d: ", i);
		if (item->type == STARPU_VALUE)
		{
			fprintf(stderr, "STARPU_VALUE - ");
			fprintf(stderr, "Value=%d\n", (*(int *)(item->ptr)));
		}
		else if (item->type == STARPU_R)
		{
			val = *(int*)starpu_data_handle_to_pointer(*(starpu_data_handle_t*)(item->ptr), 0);
			fprintf(stderr, "STARPU_R - Value=%d\n", val);
		}
		else if (item->type == STARPU_DATA_ARRAY)
		{
			fprintf(stderr, "STARPU_DATA_ARRAY - Multiple values: %d", *(int*)starpu_data_handle_to_pointer(*((starpu_data_handle_t*)item->ptr), 0));
			for (int j=1 ; j<STARPU_MIN(item->count, 5) ; j++)
			{
				fprintf(stderr, ", %d", *(int*)starpu_data_handle_to_pointer(((starpu_data_handle_t*)item->ptr)[j], 0));
			}
			fprintf(stderr, "...\n");
		}
		else
		{
			fprintf(stderr, "Unrecognized type.\n");
		}
		item = _starpu_mpi_checkpoint_template_get_next_data(cp_template, item);
		i++;
	}
	return 0;
}

int starpu_mpi_checkpoint_turn_on(void)
{
	starpu_pthread_mutex_init(&cp_template_mutex, NULL);
	_starpu_mpi_req_list_init(&detached_ft_service_requests);
	starpu_pthread_mutex_init(&detached_ft_service_requests_mutex, NULL);
	starpu_mpi_comm_rank(MPI_COMM_WORLD, &my_rank); //TODO: check compatibility with several Comms behaviour
	return 0;
}

int starpu_mpi_checkpoint_turn_off(void)
{
	for (int i=0 ; i<cp_template_number ; i++)
	{
		starpu_pthread_mutex_destroy(&cp_template_array[i]->mutex);
	}
	starpu_pthread_mutex_destroy(&cp_template_mutex);
	starpu_pthread_mutex_destroy(&detached_ft_service_requests_mutex);
	return 0;
}

void _print_ack_sent_cb(void* _args)
{
	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
	fprintf(stderr, "Ack sent for cpid:%d, cpinst:%d - to rank %d\n", arg->msg.checkpoint_id, arg->msg.checkpoint_instance, arg->rank);
	free(_args);
}

void _starpu_mpi_treat_cache_ack_no_lock_cb(void* args)
{
	starpu_mpi_checkpoint_template_t cp_template = (starpu_mpi_checkpoint_template_t)args;
	cp_template->remaining_ack_awaited--;
}

void _starpu_mpi_treat_ack_receipt_cb(void* _args)
{
	struct _starpu_mpi_cp_ack_msg* msg = (struct _starpu_mpi_cp_ack_msg*) _args;
	starpu_pthread_mutex_lock(&cp_template_mutex);
	for (int i=0 ; i<cp_template_number ; i++)
	{
		starpu_pthread_mutex_lock(&cp_template_array[i]->mutex);
		if (cp_template_array[i]->cp_template_id == msg->checkpoint_id && cp_template_array[i]->cp_template_current_instance == msg->checkpoint_instance)
		{
			cp_template_array[i]->remaining_ack_awaited--;
			if (cp_template_array[i]->remaining_ack_awaited == 0)
			{
				// TODO: share info about cp integrity
				fprintf(stderr, "All cp material for cpid:%d, cpinst:%d - has been sent and acknowledged.\n", msg->checkpoint_id, msg->checkpoint_instance);
				cp_template_array[i]->pending = 0;
			}
			free(msg);
			starpu_pthread_mutex_unlock(&cp_template_array[i]->mutex);
			starpu_pthread_mutex_unlock(&cp_template_mutex);
			return;
		}
		starpu_pthread_mutex_unlock(&cp_template_array[i]->mutex);
	}
	starpu_pthread_mutex_unlock(&cp_template_mutex);
}
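/* For reference, the ack structures used above and below are defined in the
 * checkpoint header, not in this file; from the fields actually dereferenced
 * here they are assumed to be shaped along the lines of:
 *
 *	struct _starpu_mpi_cp_ack_msg
 *	{
 *		int checkpoint_id;       // template id, as given at registration
 *		int checkpoint_instance; // instance counter of the template
 *	};
 *
 *	struct _starpu_mpi_cp_ack_arg_cb
 *	{
 *		int rank;                          // peer to exchange the ack with
 *		struct _starpu_mpi_cp_ack_msg msg; // payload sent/expected on the wire
 *	};
 *
 * Only the fields used in this file are listed. */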
/* Run on the backup once the checkpoint data has been received: post the
 * detached send of the ack message back to the data owner. */
void _starpu_mpi_post_cp_ack_send_cb(void* _args)
{
	struct _starpu_mpi_req* req;
	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;
	fprintf(stderr, "Send cb\n");

	/* Initialize the request structure */
	_starpu_mpi_request_init(&req);
	req->request_type = SEND_REQ;
	/* prio_list is sorted by increasing values */
	if (_starpu_mpi_use_prio)
		req->prio = 0;
	req->data_handle = NULL;
	req->node_tag.node.rank = arg->rank;
	req->node_tag.data_tag = _STARPU_MPI_TAG_CP_ACK;
	req->node_tag.node.comm = MPI_COMM_WORLD;
	req->detached = 1;
	req->ptr = (void*)&arg->msg;
	req->sync = 0;
	req->datatype = MPI_BYTE;
	req->callback = _print_ack_sent_cb;
	req->callback_arg = arg;
	req->func = NULL;
	req->sequential_consistency = 1;
	req->count = sizeof(struct _starpu_mpi_cp_ack_msg);
	_mpi_backend._starpu_mpi_backend_request_fill(req, MPI_COMM_WORLD, 0);

	STARPU_PTHREAD_MUTEX_LOCK(&detached_ft_service_requests_mutex);
	MPI_Isend(req->ptr, req->count, req->datatype, req->node_tag.node.rank, req->node_tag.data_tag, req->node_tag.node.comm, &req->backend->data_request);
	_starpu_mpi_req_list_push_back(&detached_ft_service_requests, req);
	fprintf(stderr, "pushed send: %p in list %p - prev: %p - next: %p - dest:%d - tag:%d\n", req, &detached_ft_service_requests, _starpu_mpi_req_list_prev(req), _starpu_mpi_req_list_next(req), req->node_tag.node.rank, (int)req->node_tag.data_tag);
	detached_send_n_ft_service_requests++;
	req->submitted = 1;
	STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
}

/* Run on the owner once the checkpoint data has been sent: post the detached
 * recv of the ack message expected from the backup. */
void _starpu_mpi_post_cp_ack_recv_cb(void* _args)
{
	struct _starpu_mpi_req* req;
	struct _starpu_mpi_cp_ack_arg_cb* arg = (struct _starpu_mpi_cp_ack_arg_cb*) _args;

	/* Initialize the request structure */
	_starpu_mpi_request_init(&req);
	req->request_type = RECV_REQ;
	/* prio_list is sorted by increasing values */
	if (_starpu_mpi_use_prio)
		req->prio = 0;
	req->data_handle = NULL;
	req->node_tag.node.rank = arg->rank;
	req->node_tag.data_tag = _STARPU_MPI_TAG_CP_ACK;
	req->node_tag.node.comm = MPI_COMM_WORLD;
	req->detached = 1;
	req->ptr = malloc(sizeof(struct _starpu_mpi_cp_ack_msg));
	req->sync = 0;
	req->datatype = MPI_BYTE;
	req->callback = _starpu_mpi_treat_ack_receipt_cb;
	req->callback_arg = req->ptr;
	req->func = NULL;
	req->sequential_consistency = 1;
	req->count = sizeof(struct _starpu_mpi_cp_ack_msg);
	_mpi_backend._starpu_mpi_backend_request_fill(req, MPI_COMM_WORLD, 0);

	STARPU_PTHREAD_MUTEX_LOCK(&detached_ft_service_requests_mutex);
	MPI_Irecv(req->ptr, req->count, req->datatype, req->node_tag.node.rank, req->node_tag.data_tag, req->node_tag.node.comm, &req->backend->data_request);
	_starpu_mpi_req_list_push_back(&detached_ft_service_requests, req);
	fprintf(stderr, "pushed recv: %p in list %p - prev: %p - next: %p - src:%d - tag:%d\n", req, &detached_ft_service_requests, _starpu_mpi_req_list_prev(req), _starpu_mpi_req_list_next(req), req->node_tag.node.rank, (int)req->node_tag.data_tag);
	req->submitted = 1;
	STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
}
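/* Note on the two callbacks above: ack messages deliberately bypass the
 * regular data-handle path (req->data_handle is NULL, the payload is raw
 * MPI_BYTE) and use the reserved tag _STARPU_MPI_TAG_CP_ACK. The requests are
 * parked in detached_ft_service_requests and only make progress when
 * _starpu_mpi_test_ft_detached_requests() below polls them, i.e. whenever
 * starpu_mpi_ft_progress() is called. */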
static void _starpu_mpi_handle_request_termination(struct _starpu_mpi_req *req)
{
	_STARPU_MPI_LOG_IN();

	_STARPU_MPI_DEBUG(2, "complete MPI request %p type %s tag %"PRIi64" src %d data %p ptr %p datatype '%s' count %d registered_datatype %d internal_req %p\n",
			  req, _starpu_mpi_request_type(req->request_type), req->node_tag.data_tag, req->node_tag.node.rank, req->data_handle, req->ptr,
			  req->datatype_name, (int)req->count, req->registered_datatype, req->backend->internal_req);

	if (req->backend->internal_req)
	{
		free(req->backend->early_data_handle);
		req->backend->early_data_handle = NULL;
	}
	else
	{
		if (req->request_type == RECV_REQ || req->request_type == SEND_REQ)
		{
			if (req->registered_datatype == 0)
			{
				if (req->request_type == SEND_REQ)
				{
					// We need to make sure the communication for sending the size
					// has completed, as MPI can re-order messages, let's call
					// MPI_Wait to make sure data have been sent
					int ret;
					ret = MPI_Wait(&req->backend->size_req, MPI_STATUS_IGNORE);
					STARPU_MPI_ASSERT_MSG(ret == MPI_SUCCESS, "MPI_Wait returning %s", _starpu_mpi_get_mpi_error_code(ret));
					starpu_free_on_node_flags(STARPU_MAIN_RAM, (uintptr_t)req->ptr, req->count, 0);
					req->ptr = NULL;
				}
				else if (req->request_type == RECV_REQ)
				{
					// req->ptr is freed by starpu_data_unpack
					starpu_data_unpack(req->data_handle, req->ptr, req->count);
					starpu_memory_deallocate(STARPU_MAIN_RAM, req->count);
				}
			}
			else
			{
				//_starpu_mpi_datatype_free(req->data_handle, &req->datatype);
			}
		}
		_STARPU_MPI_TRACE_TERMINATED(req, req->node_tag.node.rank, req->node_tag.data_tag);
	}

	_starpu_mpi_release_req_data(req);

	if (req->backend->envelope)
	{
		free(req->backend->envelope);
		req->backend->envelope = NULL;
	}

	/* Execute the specified callback, if any */
	if (req->callback)
		req->callback(req->callback_arg);

	/* tell anyone potentially waiting on the request that it is
	 * terminated now */
	STARPU_PTHREAD_MUTEX_LOCK(&req->backend->req_mutex);
	req->completed = 1;
	STARPU_PTHREAD_COND_BROADCAST(&req->backend->req_cond);
	STARPU_PTHREAD_MUTEX_UNLOCK(&req->backend->req_mutex);
	_STARPU_MPI_LOG_OUT();
}

static void _starpu_mpi_test_ft_detached_requests(void)
{
	//_STARPU_MPI_LOG_IN();
	int flag;
	struct _starpu_mpi_req *req;

	STARPU_PTHREAD_MUTEX_LOCK(&detached_ft_service_requests_mutex);
	if (_starpu_mpi_req_list_empty(&detached_ft_service_requests))
	{
		STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
		//_STARPU_MPI_LOG_OUT();
		return;
	}

	_STARPU_MPI_TRACE_TESTING_DETACHED_BEGIN();
	req = _starpu_mpi_req_list_begin(&detached_ft_service_requests);
	while (req != _starpu_mpi_req_list_end(&detached_ft_service_requests))
	{
		STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);

		_STARPU_MPI_TRACE_TEST_BEGIN(req->node_tag.node.rank, req->node_tag.data_tag);
		//_STARPU_MPI_DEBUG(3, "Test detached request %p - mpitag %"PRIi64" - TYPE %s %d\n", &req->backend->data_request, req->node_tag.data_tag, _starpu_mpi_request_type(req->request_type), req->node_tag.node.rank);
#ifdef STARPU_SIMGRID
		req->ret = _starpu_mpi_simgrid_mpi_test(&req->done, &flag);
#else
		STARPU_MPI_ASSERT_MSG(req->backend->data_request != MPI_REQUEST_NULL, "Cannot test completion of the request MPI_REQUEST_NULL");
		req->ret = MPI_Test(&req->backend->data_request, &flag, MPI_STATUS_IGNORE);
#endif
		STARPU_MPI_ASSERT_MSG(req->ret == MPI_SUCCESS, "MPI_Test returning %s", _starpu_mpi_get_mpi_error_code(req->ret));
		_STARPU_MPI_TRACE_TEST_END(req->node_tag.node.rank, req->node_tag.data_tag);

		if (!flag)
		{
			req = _starpu_mpi_req_list_next(req);
		}
		else
		{
			fprintf(stderr, "req success: %u\n", detached_send_n_ft_service_requests);
			_STARPU_MPI_TRACE_POLLING_END();
			struct _starpu_mpi_req *next_req;
			next_req = _starpu_mpi_req_list_next(req);

			_STARPU_MPI_TRACE_COMPLETE_BEGIN(req->request_type, req->node_tag.node.rank, req->node_tag.data_tag);

			STARPU_PTHREAD_MUTEX_LOCK(&detached_ft_service_requests_mutex);
			if (req->request_type == SEND_REQ)
				detached_send_n_ft_service_requests--;
			_starpu_mpi_req_list_erase(&detached_ft_service_requests, req);
			STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);

			_starpu_mpi_handle_request_termination(req);

			_STARPU_MPI_TRACE_COMPLETE_END(req->request_type, req->node_tag.node.rank, req->node_tag.data_tag);

			STARPU_PTHREAD_MUTEX_LOCK(&req->backend->req_mutex);
			/* We don't want to free internal non-detached requests, we need to get their MPI request before destroying them */
			if (req->backend->is_internal_req && !req->backend->to_destroy)
			{
				/* We have completed the request, let the application request destroy it */
				req->backend->to_destroy = 1;
				STARPU_PTHREAD_MUTEX_UNLOCK(&req->backend->req_mutex);
			}
			else
			{
				STARPU_PTHREAD_MUTEX_UNLOCK(&req->backend->req_mutex);
				_starpu_mpi_request_destroy(req);
			}

			req = next_req;
			_STARPU_MPI_TRACE_POLLING_BEGIN();
		}

		STARPU_PTHREAD_MUTEX_LOCK(&detached_ft_service_requests_mutex);
	}
	_STARPU_MPI_TRACE_TESTING_DETACHED_END();
	STARPU_PTHREAD_MUTEX_UNLOCK(&detached_ft_service_requests_mutex);
	//_STARPU_MPI_LOG_OUT();
}
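/* The polling loop above uses a drop-and-retake locking pattern: the list
 * mutex is released around each MPI_Test so that completion callbacks which
 * post new service requests (e.g. _starpu_mpi_post_cp_ack_send_cb) do not
 * deadlock on it, and is re-taken before the iterator advances. next_req is
 * saved before the completed request is erased, since erasing invalidates
 * the list iterator. */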
void starpu_mpi_ft_progress(void)
{
	_starpu_mpi_test_ft_detached_requests();
}

int starpu_mpi_ft_busy(void)
{
	return !_starpu_mpi_req_list_empty(&detached_ft_service_requests);
}
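/* Usage sketch (illustrative; the application variable cp is hypothetical):
 * after submitting a checkpoint instance, a test program can drive the ack
 * traffic to completion by polling the FT progress hook explicitly:
 *
 *	starpu_mpi_checkpoint_turn_on();
 *	// ... register a checkpoint template into cp ...
 *	_starpu_mpi_checkpoint_template_submit(cp);
 *	while (starpu_mpi_ft_busy())
 *		starpu_mpi_ft_progress();
 *	starpu_mpi_checkpoint_turn_off();
 *
 * In the real runtime the MPI progress engine is expected to call
 * starpu_mpi_ft_progress() itself; the explicit loop is only for testing. */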