瀏覽代碼

blockized CUDA version of vector_scal

Samuel Thibault 14 年之前
父節點
當前提交
b990d1d43f
共有 1 個文件被更改,包括 6 次插入4 次删除
  1. 6 4
      examples/basic_examples/vector_scal_cuda.cu

+ 6 - 4
examples/basic_examples/vector_scal_cuda.cu

@@ -23,8 +23,9 @@
 static __global__ void vector_mult_cuda(float *val, unsigned n,
                                         float factor)
 {
-        unsigned i;
-        for(i = 0 ; i < n ; i++)
+        unsigned i = blockIdx.x * blockDim.x + threadIdx.x; /* flat index over the whole grid, not just within one block */
+        /* bounds guard: the grid size is rounded up, so trailing threads may fall past n */
+	if (i < n)
                 val[i] *= factor;
 }
 
@@ -36,9 +37,10 @@ extern "C" void scal_cuda_func(void *buffers[], void *_args)
         unsigned n = STARPU_VECTOR_GET_NX(buffers[0]);
         /* local copy of the vector pointer */
         float *val = (float *)STARPU_VECTOR_GET_PTR(buffers[0]);
+	unsigned threads_per_block = 64;
+	unsigned nblocks = (n + threads_per_block-1) / threads_per_block;
 
-        /* TODO: use more blocks and threads in blocks */
-        vector_mult_cuda<<<1,1>>>(val, n, *factor);
+        vector_mult_cuda<<<nblocks,threads_per_block>>>(val, n, *factor);
 
 	cudaThreadSynchronize();
 }