mpi_reduction_kernels.c 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. /* StarPU --- Runtime system for heterogeneous multicore architectures.
  2. *
  3. * Copyright (C) 2012,2017 Inria
  4. * Copyright (C) 2012,2013,2015,2017 CNRS
  5. * Copyright (C) 2013 Université de Bordeaux
  6. *
  7. * StarPU is free software; you can redistribute it and/or modify
  8. * it under the terms of the GNU Lesser General Public License as published by
  9. * the Free Software Foundation; either version 2.1 of the License, or (at
  10. * your option) any later version.
  11. *
  12. * StarPU is distributed in the hope that it will be useful, but
  13. * WITHOUT ANY WARRANTY; without even the implied warranty of
  14. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  15. *
  16. * See the GNU Lesser General Public License in COPYING.LGPL for more details.
  17. */
  18. #include <starpu.h>
  19. #include <mpi.h>
  20. #include "helper.h"
  21. /*
  22. * Codelet to create a neutral element
  23. */
  24. void init_cpu_func(void *descr[], void *cl_arg)
  25. {
  26. (void)cl_arg;
  27. long int *dot = (long int *)STARPU_VARIABLE_GET_PTR(descr[0]);
  28. *dot = 0;
  29. FPRINTF_MPI(stderr, "Init dot\n");
  30. }
  31. /*
  32. * Codelet to perform the reduction of two elements
  33. */
  34. void redux_cpu_func(void *descr[], void *cl_arg)
  35. {
  36. (void)cl_arg;
  37. long int *dota = (long int *)STARPU_VARIABLE_GET_PTR(descr[0]);
  38. long int *dotb = (long int *)STARPU_VARIABLE_GET_PTR(descr[1]);
  39. *dota = *dota + *dotb;
  40. FPRINTF_MPI(stderr, "Calling redux %ld=%ld+%ld\n", *dota, *dota-*dotb, *dotb);
  41. }
  42. /*
  43. * Dot product codelet
  44. */
  45. void dot_cpu_func(void *descr[], void *cl_arg)
  46. {
  47. (void)cl_arg;
  48. long int *local_x = (long int *)STARPU_VECTOR_GET_PTR(descr[0]);
  49. unsigned n = STARPU_VECTOR_GET_NX(descr[0]);
  50. long int *dot = (long int *)STARPU_VARIABLE_GET_PTR(descr[1]);
  51. //FPRINTF_MPI(stderr, "Before dot=%ld (adding %d elements...)\n", *dot, n);
  52. unsigned i;
  53. for (i = 0; i < n; i++)
  54. {
  55. //FPRINTF_MPI(stderr, "Adding %ld\n", local_x[i]);
  56. *dot += local_x[i];
  57. }
  58. //FPRINTF_MPI(stderr, "After dot=%ld\n", *dot);
  59. }
  60. /*
  61. * Display codelet
  62. */
  63. void display_cpu_func(void *descr[], void *cl_arg)
  64. {
  65. (void)cl_arg;
  66. long int *local_x = (long int *)STARPU_VARIABLE_GET_PTR(descr[0]);
  67. FPRINTF_MPI(stderr, "Local=%ld\n", *local_x);
  68. }