Browse Source

Include the function to filter the CSR matrix directly in the example to make
it easier to understand.

Cédric Augonnet 14 years ago
parent
commit
a0f18ecea3
1 changed files with 34 additions and 1 deletions
  1. 34 1
      examples/spmv/spmv.c

+ 34 - 1
examples/spmv/spmv.c

@@ -40,9 +40,42 @@ static void parse_args(int argc, char **argv)
 	}
 }
 
+/* This filter function takes a CSR matrix, and divides it into nparts with the
+ * same number of non-zero entries. */
+static void csr_filter_func(void *father_interface, void *child_interface, struct starpu_data_filter *f, unsigned id, unsigned nparts)
+{
+	starpu_csr_interface_t *csr_father = father_interface;
+	starpu_csr_interface_t *csr_child = child_interface;
+
+	uint32_t nrow = csr_father->nrow;
+	size_t elemsize = csr_father->elemsize;
+	uint32_t firstentry = csr_father->firstentry;
+
+	/* Every sub-parts should contain the same number of non-zero entries */
+	uint32_t chunk_size = (nrow + nparts - 1)/nparts;
+	uint32_t *rowptr = csr_father->rowptr;
+
+	uint32_t first_index = id*chunk_size - firstentry;
+	uint32_t local_firstentry = rowptr[first_index];
+	
+	uint32_t child_nrow = STARPU_MIN(chunk_size, nrow - id*chunk_size);
+	uint32_t local_nnz = rowptr[first_index + child_nrow] - rowptr[first_index]; 
+	
+	csr_child->nnz = local_nnz;
+	csr_child->nrow = child_nrow;
+	csr_child->firstentry = local_firstentry;
+	csr_child->elemsize = elemsize;
+	
+	if (csr_father->nzval) {
+		csr_child->rowptr = &csr_father->rowptr[first_index];
+		csr_child->colind = &csr_father->colind[local_firstentry];
+		csr_child->nzval = csr_father->nzval + local_firstentry * elemsize;
+	}
+}
+
 /* partition the CSR matrix along a block distribution */
 static struct starpu_data_filter csr_f = {
-	.filter_func = starpu_vertical_block_filter_func_csr,
+	.filter_func = csr_filter_func,
 	/* This value is defined later on */
 	.nchildren = -1,
 	.get_nchildren = NULL,