Quellcode durchsuchen

julia gemm: Adapt the number of blocks to the number of workers and set OMP_NUM_THREADS=1.

Pierre Huchant vor 5 Jahren
Ursprung
Commit
82ccc6c9d2
4 geänderte Dateien mit 16 neuen und 3 gelöschten Zeilen
  1. 4 2
      julia/examples/gemm/gemm.jl
  2. 2 1
      julia/examples/gemm/gemm.sh
  3. 5 0
      julia/src/StarPU.jl
  4. 5 0
      julia/src/translate_headers.jl

+ 4 - 2
julia/examples/gemm/gemm.jl

@@ -15,6 +15,7 @@
 #
 #
 using StarPU
 using StarPU
 using LinearAlgebra.BLAS
 using LinearAlgebra.BLAS
+using BenchmarkTools
 
 
 @target STARPU_CPU+STARPU_CUDA
 @target STARPU_CPU+STARPU_CUDA
 @codelet function gemm(A :: Matrix{Float32}, B :: Matrix{Float32}, C :: Matrix{Float32}, alpha :: Float32, beta :: Float32) :: Nothing
 @codelet function gemm(A :: Matrix{Float32}, B :: Matrix{Float32}, C :: Matrix{Float32}, alpha :: Float32, beta :: Float32) :: Nothing
@@ -54,7 +55,6 @@ function multiply_with_starpu(A :: Matrix{Float32}, B :: Matrix{Float32}, C :: M
                     end
                     end
                 end
                 end
             end
             end
-            starpu_task_wait_for_all()
             t=time_ns()-t
             t=time_ns()-t
             if (tmin==0 || tmin>t)
             if (tmin==0 || tmin>t)
                 tmin=t
                 tmin=t
@@ -135,8 +135,10 @@ end
 
 
 starpu_init()
 starpu_init()
 starpu_cublas_init()
 starpu_cublas_init()
+nblock_x = Int32(ceil(sqrt(starpu_worker_get_count())))
+nblock_y = nblock_x
 io=open(filename,"w")
 io=open(filename,"w")
-compute_times(io,64,512,4096,1,1)
+compute_times(io,64,512,4096,nblock_x,nblock_y)
 close(io)
 close(io)
 
 
 starpu_shutdown()
 starpu_shutdown()

+ 2 - 1
julia/examples/gemm/gemm.sh

@@ -15,7 +15,8 @@
 # See the GNU Lesser General Public License in COPYING.LGPL for more details.
 # See the GNU Lesser General Public License in COPYING.LGPL for more details.
 #
 #
 
 
-$(dirname $0)/../execute.sh gemm/gemm.jl
 $(dirname $0)/../execute.sh gemm/gemm_native.jl
 $(dirname $0)/../execute.sh gemm/gemm_native.jl
 
 
+export OMP_NUM_THREADS=1
+$(dirname $0)/../execute.sh gemm/gemm.jl
 
 

+ 5 - 0
julia/src/StarPU.jl

@@ -112,5 +112,10 @@ export starpu_data_get_default_sequential_consistency_flag
 export starpu_data_set_default_sequential_consistency_flag
 export starpu_data_set_default_sequential_consistency_flag
 export starpu_data_get_sequential_consistency_flag
 export starpu_data_get_sequential_consistency_flag
 export starpu_data_set_sequential_consistency_flag
 export starpu_data_set_sequential_consistency_flag
+export starpu_worker_get_count
+export starpu_cpu_worker_get_count
+export starpu_cuda_worker_get_count
+export starpu_opencl_worker_get_count
+export starpu_mic_worker_get_count
 
 
 end
 end

+ 5 - 0
julia/src/translate_headers.jl

@@ -85,6 +85,11 @@ function starpu_translate_headers()
                                "starpu_task_declare_deps_array",
                                "starpu_task_declare_deps_array",
                                "starpu_iteration_push",
                                "starpu_iteration_push",
                                "starpu_iteration_pop",
                                "starpu_iteration_pop",
+                               "starpu_worker_get_count",
+                               "starpu_cpu_worker_get_count",
+                               "starpu_cuda_worker_get_count",
+                               "starpu_opencl_worker_get_count",
+                               "starpu_mic_worker_get_count",
                                "STARPU_CPU",
                                "STARPU_CPU",
                                "STARPU_CUDA",
                                "STARPU_CUDA",
                                "STARPU_CUDA_ASYNC",
                                "STARPU_CUDA_ASYNC",