瀏覽代碼

Check spmv computation

Samuel Thibault 5 年之前
父節點
當前提交
e703f1fb92
共有 1 個文件被更改,包括 28 次插入6 次删除
  1. 28 6
      examples/spmv/spmv.c

+ 28 - 6
examples/spmv/spmv.c

@@ -2,7 +2,7 @@
  *
  * Copyright (C) 2011-2013                                Inria
  * Copyright (C) 2010-2013,2015,2017                      CNRS
- * Copyright (C) 2009-2011,2013-2015                      Université de Bordeaux
+ * Copyright (C) 2009-2011,2013-2015,2020                 Université de Bordeaux
  * Copyright (C) 2010                                     Mehdi Juhoor
  *
  * StarPU is free software; you can redistribute it and/or modify
@@ -136,6 +136,7 @@ int main(int argc, char **argv)
 	/* Input and Output vectors */
 	float *vector_in_ptr;
 	float *vector_out_ptr;
+	float *vector_exp_out_ptr;
 
 	/*
 	 *	Parse command-line arguments
@@ -159,6 +160,9 @@ int main(int argc, char **argv)
 	starpu_malloc((void **)&rowptr, (size+1)*sizeof(uint32_t));
 	assert(nzval && colind && rowptr);
 
+#define UPPER_BAND 1.
+#define MIDDLE_BAND 5.
+#define LOWER_BAND 1.
 	/* fill the matrix */
 	for (row = 0, pos = 0; row < size; row++)
 	{
@@ -166,18 +170,18 @@ int main(int argc, char **argv)
 
 		if (row > 0)
 		{
-			nzval[pos] = 1.0f;
+			nzval[pos] = LOWER_BAND;
 			colind[pos] = row-1;
 			pos++;
 		}
 		
-		nzval[pos] = 5.0f;
+		nzval[pos] = MIDDLE_BAND;
 		colind[pos] = row;
 		pos++;
 
 		if (row < size - 1)
 		{
-			nzval[pos] = 1.0f;
+			nzval[pos] = UPPER_BAND;
 			colind[pos] = row+1;
 			pos++;
 		}
@@ -190,12 +194,13 @@ int main(int argc, char **argv)
 	/* initiate the 2 vectors */
 	starpu_malloc((void **)&vector_in_ptr, size*sizeof(float));
 	starpu_malloc((void **)&vector_out_ptr, size*sizeof(float));
-	assert(vector_in_ptr && vector_out_ptr);
+	starpu_malloc((void **)&vector_exp_out_ptr, size*sizeof(float));
+	assert(vector_in_ptr && vector_out_ptr && vector_exp_out_ptr);
 
 	/* fill them */
 	for (ind = 0; ind < size; ind++)
 	{
-		vector_in_ptr[ind] = 2.0f;
+		vector_in_ptr[ind] = ind % 100;
 		vector_out_ptr[ind] = 0.0f;
 	}
 
@@ -267,11 +272,28 @@ int main(int argc, char **argv)
                 FPRINTF(stdout, "%2.2f\t%2.2f\n", vector_in_ptr[row], vector_out_ptr[row]);
 	}
 
+	/* Check the result */
+	memset(vector_exp_out_ptr, 0, sizeof(vector_exp_out_ptr[0])*size);
+	for (row = 0; row < size; row++)
+	{
+		if (row > 0)
+			vector_exp_out_ptr[row] += LOWER_BAND * vector_in_ptr[row-1];
+		vector_exp_out_ptr[row] += MIDDLE_BAND * vector_in_ptr[row];
+		if (row < size-1)
+			vector_exp_out_ptr[row] += UPPER_BAND * vector_in_ptr[row+1];
+	}
+	for (row = 0; row < size; row++)
+		if (vector_out_ptr[row] != vector_exp_out_ptr[row]) {
+			FPRINTF(stderr, "check failed at %u: %f vs expected %f\n", row, vector_out_ptr[row], vector_exp_out_ptr[row]);
+			exit(EXIT_FAILURE);
+		}
+
 	starpu_free(nzval);
 	starpu_free(colind);
 	starpu_free(rowptr);
 	starpu_free(vector_in_ptr);
 	starpu_free(vector_out_ptr);
+	starpu_free(vector_exp_out_ptr);
 
 	/*
 	 *	Stop StarPU