Browse Source

Fix the SSE example.

Cyril Roelandt 13 years ago
parent
commit
dc3ce715e4
2 changed files with 29 additions and 11 deletions
  1. +15 -6
      doc/chapters/vector_scal_cpu.texi
  2. +14 -5
      examples/basic_examples/vector_scal_cpu.c

+ 15 - 6
doc/chapters/vector_scal_cpu.texi

@@ -45,15 +45,24 @@ void scal_sse_func(void *buffers[], void *cl_arg)
     float *vector = (float *) STARPU_VECTOR_GET_PTR(buffers[0]);
     unsigned int n = STARPU_VECTOR_GET_NX(buffers[0]);
     unsigned int n_iterations = n/4;
-    if (n % 4 != 0)
-        n_iterations++;
 
     __m128 *VECTOR = (__m128*) vector;
-    __m128 factor __attribute__((aligned(16)));
-    factor = _mm_set1_ps(*(float *) cl_arg);
+    __m128 FACTOR __attribute__((aligned(16)));
+    float factor = *(float *) cl_arg;
+    FACTOR = _mm_set1_ps(factor);
 
-    unsigned int i;
+    unsigned int i;	
     for (i = 0; i < n_iterations; i++)
-        VECTOR[i] = _mm_mul_ps(factor, VECTOR[i]);
+        VECTOR[i] = _mm_mul_ps(FACTOR, VECTOR[i]);
+
+    int remainder = n%4;
+    if (remainder != 0)
+    @{
+        int start = 4 * n_iterations;
+        for (i = start; i < start+remainder; ++i)
+        @{
+            vector[i] = factor * vector[i];
+        @}
+    @}
 @}
 @end smallexample

+ 14 - 5
examples/basic_examples/vector_scal_cpu.c

@@ -61,15 +61,24 @@ void scal_sse_func(void *buffers[], void *cl_arg)
 	float *vector = (float *) STARPU_VECTOR_GET_PTR(buffers[0]);
 	unsigned int n = STARPU_VECTOR_GET_NX(buffers[0]);
 	unsigned int n_iterations = n/4;
-	if (n % 4 != 0)
-		n_iterations++;
 
 	__m128 *VECTOR = (__m128*) vector;
-	__m128 factor __attribute__((aligned(16)));
-	factor = _mm_set1_ps(*(float *) cl_arg);
+	__m128 FACTOR __attribute__((aligned(16)));
+	float factor = *(float *) cl_arg;
+	FACTOR = _mm_set1_ps(factor);
 
 	unsigned int i;	
 	for (i = 0; i < n_iterations; i++)
-		VECTOR[i] = _mm_mul_ps(factor, VECTOR[i]);
+		VECTOR[i] = _mm_mul_ps(FACTOR, VECTOR[i]);
+
+	int remainder = n%4;
+	if (remainder != 0)
+	{
+		int start = 4 * n_iterations;
+		for (i = start; i < start+remainder; ++i)
+		{
+			vector[i] = factor * vector[i];
+		}
+	}
 }
 #endif