|
@@ -41,23 +41,40 @@ static inline void chol_common_cpu_codelet_update_u22(void *descr[], int s, __at
|
|
|
unsigned ld12 = STARPU_MATRIX_GET_LD(descr[1]);
|
|
|
unsigned ld22 = STARPU_MATRIX_GET_LD(descr[2]);
|
|
|
|
|
|
- switch (s) {
|
|
|
- case 0:
|
|
|
+ if (s == 0)
|
|
|
+ {
|
|
|
+ int worker_size = starpu_combined_worker_get_size();
|
|
|
+
|
|
|
+ if (worker_size == 1)
|
|
|
+ {
|
|
|
+ /* Sequential CPU kernel */
|
|
|
SGEMM("N", "T", dy, dx, dz, -1.0f, left, ld21,
|
|
|
right, ld12, 1.0f, center, ld22);
|
|
|
- break;
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ /* Parallel CPU kernel */
|
|
|
+ int rank = starpu_combined_worker_get_rank();
|
|
|
+
|
|
|
+ int block_size = (dx + worker_size - 1)/worker_size;
|
|
|
+ int new_dx = STARPU_MIN(dx, block_size*(rank+1)) - block_size*rank;
|
|
|
+
|
|
|
+ float *new_left = &left[block_size*rank];
|
|
|
+ float *new_center = ¢er[block_size*rank];
|
|
|
+
|
|
|
+ SGEMM("N", "T", dy, new_dx, dz, -1.0f, new_left, ld21,
|
|
|
+ right, ld12, 1.0f, new_center, ld22);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ /* CUDA kernel */
|
|
|
#ifdef STARPU_USE_CUDA
|
|
|
- case 1:
|
|
|
- cublasSgemm('n', 't', dy, dx, dz,
|
|
|
- -1.0f, left, ld21, right, ld12,
|
|
|
- 1.0f, center, ld22);
|
|
|
- cudaStreamSynchronize(starpu_cuda_get_local_stream());
|
|
|
-
|
|
|
- break;
|
|
|
+ cublasSgemm('n', 't', dy, dx, dz,
|
|
|
+ -1.0f, left, ld21, right, ld12,
|
|
|
+ 1.0f, center, ld22);
|
|
|
+ cudaStreamSynchronize(starpu_cuda_get_local_stream());
|
|
|
#endif
|
|
|
- default:
|
|
|
- STARPU_ABORT();
|
|
|
- break;
|
|
|
+
|
|
|
}
|
|
|
}
|
|
|
|