|
@@ -26,7 +26,25 @@
|
|
|
#include <drivers/driver_common/driver_common.h>
|
|
|
#include <drivers/mp_common/source_common.h>
|
|
|
|
|
|
+
|
|
|
+ */
|
|
|
+starpu_pthread_mutex_t htbl_mutex = STARPU_PTHREAD_MUTEX_INITIALIZER;
|
|
|
+
|
|
|
+
|
|
|
+ * a MPI MS device : its name, and its address on each device.
|
|
|
+ * If a kernel has been initialized, then a lookup has already been achieved and the
|
|
|
+ * device knows how to call it, else the host still needs to do a lookup.
|
|
|
+ */
|
|
|
+struct _starpu_mpi_ms_kernel
|
|
|
+{
|
|
|
+ UT_hash_handle hh;
|
|
|
+ char *name;
|
|
|
+ starpu_mpi_ms_kernel_t func[STARPU_MAXMPIDEVS];
|
|
|
+} *kernels;
|
|
|
+
|
|
|
|
|
|
+
|
|
|
+ * and receive informations with devices */
|
|
|
struct _starpu_mp_node *mpi_ms_nodes[STARPU_MAXMPIDEVS];
|
|
|
|
|
|
void _starpu_mpi_source_init(struct _starpu_mp_node *node)
|
|
@@ -40,6 +58,114 @@ void _starpu_mpi_source_deinit(struct _starpu_mp_node *node)
|
|
|
|
|
|
}
|
|
|
|
|
|
+int _starpu_mpi_ms_src_register_kernel(starpu_mpi_ms_func_symbol_t *symbol, const char *func_name)
|
|
|
+{
|
|
|
+ unsigned int func_name_size = (strlen(func_name) + 1) * sizeof(char);
|
|
|
+
|
|
|
+ STARPU_PTHREAD_MUTEX_LOCK(&htbl_mutex);
|
|
|
+ struct _starpu_mpi_ms_kernel *kernel;
|
|
|
+
|
|
|
+ HASH_FIND_STR(kernels, func_name, kernel);
|
|
|
+
|
|
|
+ if (kernel != NULL)
|
|
|
+ {
|
|
|
+ STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
|
|
|
+
|
|
|
+ *symbol = kernel;
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ kernel = malloc(sizeof(*kernel));
|
|
|
+ if (kernel == NULL)
|
|
|
+ {
|
|
|
+ STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
|
|
|
+ return -ENOMEM;
|
|
|
+ }
|
|
|
+
|
|
|
+ kernel->name = malloc(func_name_size);
|
|
|
+ if (kernel->name == NULL)
|
|
|
+ {
|
|
|
+ STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
|
|
|
+ free(kernel);
|
|
|
+ return -ENOMEM;
|
|
|
+ }
|
|
|
+
|
|
|
+ memcpy(kernel->name, func_name, func_name_size);
|
|
|
+
|
|
|
+ HASH_ADD_STR(kernels, name, kernel);
|
|
|
+
|
|
|
+ unsigned int nb_mpi_devices = _starpu_mpi_src_get_device_count();
|
|
|
+ unsigned int i;
|
|
|
+ for (i = 0; i < nb_mpi_devices; ++i)
|
|
|
+ kernel->func[i] = NULL;
|
|
|
+
|
|
|
+ STARPU_PTHREAD_MUTEX_UNLOCK(&htbl_mutex);
|
|
|
+
|
|
|
+ *symbol = kernel;
|
|
|
+
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+starpu_mpi_ms_kernel_t _starpu_mpi_ms_src_get_kernel(starpu_mpi_ms_func_symbol_t symbol)
|
|
|
+{
|
|
|
+ int workerid = starpu_worker_get_id();
|
|
|
+
|
|
|
+
|
|
|
+ * which will handle the task */
|
|
|
+ if (workerid < 0)
|
|
|
+ return NULL;
|
|
|
+
|
|
|
+ int devid = starpu_worker_get_devid(workerid);
|
|
|
+
|
|
|
+ struct _starpu_mpi_ms_kernel *kernel = symbol;
|
|
|
+
|
|
|
+ if (kernel->func[devid] == NULL)
|
|
|
+ {
|
|
|
+ struct _starpu_mp_node *node = mpi_ms_nodes[devid];
|
|
|
+ int ret = _starpu_src_common_lookup(node, (void (**)(void))&kernel->func[devid], kernel->name);
|
|
|
+ if (ret)
|
|
|
+ return NULL;
|
|
|
+ }
|
|
|
+
|
|
|
+ return kernel->func[devid];
|
|
|
+}
|
|
|
+
|
|
|
+void(* _starpu_mpi_ms_src_get_kernel_from_job(const struct _starpu_mp_node *node STARPU_ATTRIBUTE_UNUSED, struct _starpu_job *j))(void)
|
|
|
+{
|
|
|
+ starpu_mpi_ms_kernel_t kernel = NULL;
|
|
|
+
|
|
|
+ starpu_mpi_ms_func_t func = _starpu_task_get_mpi_ms_nth_implementation(j->task->cl, j->nimpl);
|
|
|
+ if (func)
|
|
|
+ {
|
|
|
+
|
|
|
+ * pointer to the function to execute on the device, either specified
|
|
|
+ * directly by the user or by a call to starpu_mpi_ms_get_func().
|
|
|
+ */
|
|
|
+ kernel = func();
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+
|
|
|
+ * cpu_func_name.
|
|
|
+ */
|
|
|
+ char *func_name = _starpu_task_get_cpu_name_nth_implementation(j->task->cl, j->nimpl);
|
|
|
+ if (func_name)
|
|
|
+ {
|
|
|
+ starpu_mpi_ms_func_symbol_t symbol;
|
|
|
+
|
|
|
+ _starpu_mpi_ms_src_register_kernel(&symbol, func_name);
|
|
|
+
|
|
|
+ kernel = _starpu_mpi_ms_src_get_kernel(symbol);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ STARPU_ASSERT(kernel);
|
|
|
+
|
|
|
+ return (void (*)(void))kernel;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
unsigned _starpu_mpi_src_get_device_count()
|
|
|
{
|
|
|
int nb_mpi_devices;
|