| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516 | /* * StarPU * Copyright (C) INRIA 2008-2009 (see AUTHORS file) * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as published by * the Free Software Foundation; either version 2.1 of the License, or (at * your option) any later version. * * This program is distributed in the hope that it will be useful, but * WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. * * See the GNU Lesser General Public License in COPYING.LGPL for more details. */#include "strassen.h"#include "strassen_models.h"static starpu_data_handle create_tmp_matrix(starpu_data_handle M){	float *data;	starpu_data_handle state = malloc(sizeof(starpu_data_handle));	/* create a matrix with the same dimensions as M */	uint32_t nx = starpu_matrix_get_nx(M);	uint32_t ny = starpu_matrix_get_nx(M);	STARPU_ASSERT(state);	data = malloc(nx*ny*sizeof(float));	STARPU_ASSERT(data);	starpu_matrix_data_register(&state, 0, (uintptr_t)data, nx, nx, ny, sizeof(float));		return state;}static void free_tmp_matrix(starpu_data_handle matrix){	starpu_data_unregister(matrix);	free(matrix);}static void partition_matrices(strassen_iter_state_t *iter){	starpu_data_handle A = iter->A;	starpu_data_handle B = iter->B;	starpu_data_handle C = iter->C;	starpu_filter f;	f.filter_func = starpu_block_filter_func;	f.filter_arg = 2;	starpu_filter f2;	f2.filter_func = starpu_vertical_block_filter_func;	f2.filter_arg = 2;	starpu_map_filters(A, 2, &f, &f2);	starpu_map_filters(B, 2, &f, &f2);	starpu_map_filters(C, 2, &f, &f2);	iter->A11 = starpu_data_get_sub_data(A, 2, 0, 0);	iter->A12 = starpu_data_get_sub_data(A, 2, 1, 0);	iter->A21 = starpu_data_get_sub_data(A, 2, 0, 1);	iter->A22 = starpu_data_get_sub_data(A, 2, 1, 1);	iter->B11 = starpu_data_get_sub_data(B, 2, 0, 0);	iter->B12 = starpu_data_get_sub_data(B, 2, 1, 0);	iter->B21 = starpu_data_get_sub_data(B, 2, 0, 1);	iter->B22 = starpu_data_get_sub_data(B, 2, 1, 1);	iter->C11 = starpu_data_get_sub_data(C, 2, 0, 0);	iter->C12 = starpu_data_get_sub_data(C, 2, 1, 0);	iter->C21 = starpu_data_get_sub_data(C, 2, 0, 1);	iter->C22 = starpu_data_get_sub_data(C, 2, 1, 1);	/* TODO check that all sub-matrices have the same size */}static void unpartition_matrices(strassen_iter_state_t *iter){	/* TODO there is no  need to actually gather those results ... */	starpu_data_unpartition(iter->A, 0);	starpu_data_unpartition(iter->B, 0);	starpu_data_unpartition(iter->C, 0);}static starpu_codelet cl_add = {	.where = STARPU_CPU|STARPU_CUDA,	.model = &strassen_model_add_sub,	.cpu_func = add_cpu_codelet,#ifdef STARPU_USE_CUDA	.cuda_func = add_cublas_codelet,#endif	.nbuffers = 3};static starpu_codelet cl_sub = {	.where = STARPU_CPU|STARPU_CUDA,	.model = &strassen_model_add_sub,	.cpu_func = sub_cpu_codelet,#ifdef STARPU_USE_CUDA	.cuda_func = sub_cublas_codelet,#endif	.nbuffers = 3};static starpu_codelet cl_mult = {	.where = STARPU_CPU|STARPU_CUDA,	.model = &strassen_model_mult,	.cpu_func = mult_cpu_codelet,#ifdef STARPU_USE_CUDA	.cuda_func = mult_cublas_codelet,#endif	.nbuffers = 3};static starpu_codelet cl_self_add = {	.where = STARPU_CPU|STARPU_CUDA,	.model = &strassen_model_self_add_sub,	.cpu_func = self_add_cpu_codelet,#ifdef STARPU_USE_CUDA	.cuda_func = self_add_cublas_codelet,#endif	.nbuffers = 2};static starpu_codelet cl_self_sub = {	.where = STARPU_CPU|STARPU_CUDA,	.model = &strassen_model_self_add_sub,	.cpu_func = self_sub_cpu_codelet,#ifdef STARPU_USE_CUDA	.cuda_func = self_sub_cublas_codelet,#endif	.nbuffers = 2};static void compute_add_sub_op(starpu_data_handle A1, operation op,				starpu_data_handle A2, starpu_data_handle C, 				void (*callback)(void *), void *argcallback){	/* performs C = (A op B) */	struct starpu_task *task = starpu_task_create();		task->cl_arg = NULL;		task->use_tag = 0;	task->buffers[0].handle = C;	task->buffers[0].mode = STARPU_W;	task->buffers[1].handle = A1;	task->buffers[1].mode = STARPU_R;	task->buffers[2].handle = A2;	task->buffers[2].mode = STARPU_R;		task->callback_func = callback;	task->callback_arg = argcallback;	switch (op) {		case ADD:			STARPU_ASSERT(A1);			STARPU_ASSERT(A2);			STARPU_ASSERT(C);			task->cl = &cl_add;			break;		case SUB:			STARPU_ASSERT(A1);			STARPU_ASSERT(A2);			STARPU_ASSERT(C);			task->cl = &cl_sub;			break;		case MULT:			STARPU_ASSERT(A1);			STARPU_ASSERT(A2);			STARPU_ASSERT(C);			task->cl = &cl_mult;			break;		case SELFADD:			task->buffers[0].mode = STARPU_RW;			task->cl = &cl_self_add;			break;		case SELFSUB:			task->buffers[0].mode = STARPU_RW;			task->cl = &cl_self_sub;			break;		default:			STARPU_ABORT();	}	starpu_task_submit(task);}/* Cij +=/-= Ek is done */void phase_3_callback_function(void *_arg){	unsigned cnt, use_cnt;	phase3_t *arg = _arg;	unsigned i = arg->i;	strassen_iter_state_t *iter = arg->iter;	free(arg);	use_cnt = STARPU_ATOMIC_ADD(&iter->Ei_remaining_use[i], -1);	if (use_cnt == 0) 	{		/* no one needs Ei anymore : free it */		switch (i) {			case 0:				free_tmp_matrix(iter->E1);				break;			case 1:				free_tmp_matrix(iter->E2);				break;			case 2:				free_tmp_matrix(iter->E3);				break;			case 3:				free_tmp_matrix(iter->E4);				break;			case 4:				free_tmp_matrix(iter->E5);				break;			case 5:				free_tmp_matrix(iter->E6);				break;			case 6:				free_tmp_matrix(iter->E7);				break;			default:				STARPU_ABORT();		}	}	cnt = STARPU_ATOMIC_ADD(&iter->counter, -1);	if (cnt == 0)	{		/* the entire strassen iteration is done ! */		unpartition_matrices(iter);		// XXX free the Ei		STARPU_ASSERT(iter->strassen_iter_callback);		iter->strassen_iter_callback(iter->argcb);		free(iter);	}}/* Ei is computed */void phase_2_callback_function(void *_arg){	phase2_t *arg = _arg;	strassen_iter_state_t *iter = arg->iter;	unsigned i = arg->i;	free(arg);	phase3_t *arg1, *arg2;	arg1 = malloc(sizeof(phase3_t));	arg2 = malloc(sizeof(phase3_t));	arg1->iter = iter;	arg2->iter = iter;	arg1->i = i;	arg2->i = i;	switch (i) {		case 0:			free(arg2); // will not be needed .. 			free_tmp_matrix(iter->E11);			free_tmp_matrix(iter->E12);			/* C11 += E1 */			compute_add_sub_op(iter->E1, SELFADD, NULL, iter->C11, phase_3_callback_function, arg1);			break;		case 1:			free_tmp_matrix(iter->E21);			free_tmp_matrix(iter->E22);			/* C11 += E2 */			compute_add_sub_op(iter->E2, SELFADD, NULL, iter->C11, phase_3_callback_function, arg1);			/* C22 += E2 */			compute_add_sub_op(iter->E2, SELFADD, NULL, iter->C22, phase_3_callback_function, arg2);			break;		case 2:			free(arg2); // will not be needed .. 			free_tmp_matrix(iter->E31);			free_tmp_matrix(iter->E32);			/* C22 -= E3 */			compute_add_sub_op(iter->E3, SELFSUB, NULL, iter->C22, phase_3_callback_function, arg1);			break;		case 3:			free_tmp_matrix(iter->E41);			/* C11 -= E4 */			compute_add_sub_op(iter->E4, SELFSUB, NULL, iter->C11, phase_3_callback_function, arg1);			/* C12 += E4 */			compute_add_sub_op(iter->E4, SELFADD, NULL, iter->C12, phase_3_callback_function, arg2);			break;		case 4:			free_tmp_matrix(iter->E52);			/* C12 += E5 */			compute_add_sub_op(iter->E5, SELFADD, NULL, iter->C12, phase_3_callback_function, arg1);			/* C22 += E5 */			compute_add_sub_op(iter->E5, SELFADD, NULL, iter->C22, phase_3_callback_function, arg2);			break;		case 5:			free_tmp_matrix(iter->E62);			/* C11 += E6 */			compute_add_sub_op(iter->E6, SELFADD, NULL, iter->C11, phase_3_callback_function, arg1);			/* C21 += E6 */			compute_add_sub_op(iter->E6, SELFADD, NULL, iter->C21, phase_3_callback_function, arg2);			break;		case 6:			free_tmp_matrix(iter->E71);			/* C21 += E7 */			compute_add_sub_op(iter->E7, SELFADD, NULL, iter->C21, phase_3_callback_function, arg1);			/* C22 -= E7 */			compute_add_sub_op(iter->E7, SELFSUB, NULL, iter->C22, phase_3_callback_function, arg2);			break;		default:			STARPU_ABORT();	}}/* computes Ei */static void _strassen_phase_2(strassen_iter_state_t *iter, unsigned i){	phase2_t *phase_2_arg = malloc(sizeof(phase2_t));	phase_2_arg->iter = iter;	phase_2_arg->i = i;	/* XXX */	starpu_data_handle A;	starpu_data_handle B;	starpu_data_handle C;	switch (i) {		case 0:			A = iter->E11; B = iter->E12;			iter->E1 = create_tmp_matrix(A);			C = iter->E1;			break;		case 1:			A = iter->E21; B = iter->E22;			iter->E2 = create_tmp_matrix(A);			C = iter->E2;			break;		case 2:			A = iter->E31; B = iter->E32;			iter->E3 = create_tmp_matrix(A);			C = iter->E3;			break;		case 3:			A = iter->E41; B = iter->E42;			iter->E4 = create_tmp_matrix(A);			C = iter->E4;			break;		case 4:			A = iter->E51; B = iter->E52;			iter->E5 = create_tmp_matrix(A);			C = iter->E5;			break;		case 5:			A = iter->E61; B = iter->E62;			iter->E6 = create_tmp_matrix(A);			C = iter->E6;			break;		case 6:			A = iter->E71; B = iter->E72;			iter->E7 = create_tmp_matrix(A);			C = iter->E7;			break;		default:			STARPU_ABORT();	}	STARPU_ASSERT(A);	STARPU_ASSERT(B);	STARPU_ASSERT(C);	// DEBUG XXX	//compute_add_sub_op(A, MULT, B, C, phase_2_callback_function, phase_2_arg);	strassen(A, B, C, phase_2_callback_function, phase_2_arg, iter->reclevel-1);}#define THRESHHOLD	128static void phase_1_callback_function(void *_arg){	phase1_t *arg = _arg;	strassen_iter_state_t *iter = arg->iter;	unsigned i = arg->i;	free(arg);	unsigned cnt = STARPU_ATOMIC_ADD(&iter->Ei12[i], +1);	if (cnt == 2) {		/* Ei1 and Ei2 are ready, compute Ei */		_strassen_phase_2(iter, i);	}}/* computes Ei1 or Ei2 with i in 0-6 */static void _strassen_phase_1(starpu_data_handle A1, operation opA, starpu_data_handle A2,			      starpu_data_handle C, strassen_iter_state_t *iter, unsigned i){	phase1_t *phase_1_arg = malloc(sizeof(phase1_t));	phase_1_arg->iter = iter;	phase_1_arg->i = i;	compute_add_sub_op(A1, opA, A2, C, phase_1_callback_function, phase_1_arg);}strassen_iter_state_t *init_strassen_iter_state(starpu_data_handle A, starpu_data_handle B, starpu_data_handle C, void (*strassen_iter_callback)(void *), void *argcb){	strassen_iter_state_t *iter_state = malloc(sizeof(strassen_iter_state_t));	iter_state->Ei12[0] = 0;	iter_state->Ei12[1] = 0;	iter_state->Ei12[2] = 0;	iter_state->Ei12[3] = 1; // E42 = B22	iter_state->Ei12[4] = 1; // E51 = A11	iter_state->Ei12[5] = 1; // E61 = A22	iter_state->Ei12[6] = 1; // E72 = B11	iter_state->Ei_remaining_use[0] = 1; 	iter_state->Ei_remaining_use[1] = 2;	iter_state->Ei_remaining_use[2] = 1;	iter_state->Ei_remaining_use[3] = 2;	iter_state->Ei_remaining_use[4] = 2;	iter_state->Ei_remaining_use[5] = 2;	iter_state->Ei_remaining_use[6] = 2;	unsigned i;	for (i = 0; i < 6; i++)	{		iter_state->Ei[i] = 0;	}	for (i = 0; i < 4; i++)	{		iter_state->Cij[i] = 0;	}	iter_state->strassen_iter_callback = strassen_iter_callback;	iter_state->argcb = argcb;	iter_state->A = A;	iter_state->B = B;	iter_state->C = C;	iter_state->counter = 12;	return iter_state;}static void _do_strassen(starpu_data_handle A, starpu_data_handle B, starpu_data_handle C, void (*strassen_iter_callback)(void *), void *argcb, unsigned reclevel){	/* do one level of recursion in the strassen algorithm */	strassen_iter_state_t *iter = init_strassen_iter_state(A, B, C, strassen_iter_callback, argcb);	partition_matrices(iter);	iter->reclevel = reclevel;	/* some Eij are already known */	iter->E11 = create_tmp_matrix(iter->A11);	iter->E12 = create_tmp_matrix(iter->B21);	iter->E21 = create_tmp_matrix(iter->A11);	iter->E22 = create_tmp_matrix(iter->B11);	iter->E31 = create_tmp_matrix(iter->A11);	iter->E32 = create_tmp_matrix(iter->B11);	iter->E41 = create_tmp_matrix(iter->A11);	iter->E42 = iter->B22;	iter->E51 = iter->A11;	iter->E52 = create_tmp_matrix(iter->B12);	iter->E61 = iter->A22;	iter->E62 = create_tmp_matrix(iter->B21);	iter->E71 = create_tmp_matrix(iter->A21);	iter->E72 = iter->B11;	/* compute all Eij */	_strassen_phase_1(iter->A11, SUB, iter->A22, iter->E11, iter, 0);	_strassen_phase_1(iter->B21, ADD, iter->B22, iter->E12, iter, 0);	_strassen_phase_1(iter->A11, ADD, iter->A22, iter->E21, iter, 1);	_strassen_phase_1(iter->B11, ADD, iter->B22, iter->E22, iter, 1);	_strassen_phase_1(iter->A11, SUB, iter->A21, iter->E31, iter, 2);	_strassen_phase_1(iter->B11, ADD, iter->B12, iter->E32, iter, 2);	_strassen_phase_1(iter->A11, ADD, iter->A12, iter->E41, iter, 3);	_strassen_phase_1(iter->B12, SUB, iter->B22, iter->E52, iter, 4);	_strassen_phase_1(iter->B21, SUB, iter->B11, iter->E62, iter, 5);	_strassen_phase_1(iter->A21, ADD, iter->A22, iter->E71, iter, 6);}void strassen(starpu_data_handle A, starpu_data_handle B, starpu_data_handle C, void (*callback)(void *), void *argcb, unsigned reclevel){	/* C = A * B */	if ( reclevel == 0 )	{		/* don't use Strassen but a simple sequential multiplication		 * provided this is small enough */		compute_add_sub_op(A, MULT, B, C, callback, argcb);	}	else {		_do_strassen(A, B, C, callback, argcb, reclevel);	}}
 |