瀏覽代碼

Provide wrappers for the GEMV blas functions.

Cédric Augonnet 15 年之前
父節點
當前提交
73c80890fd
共有 2 個文件被更改,包括 33 次插入0 次删除
  1. 29 0
      examples/common/blas.c
  2. 4 0
      examples/common/blas.h

+ 29 - 0
examples/common/blas.c

@@ -50,6 +50,22 @@ inline void DGEMM(char *transa, char *transb, int M, int N, int K,
 			M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);				
 }
 
+inline void SGEMV(char *transa, int M, int N, float alpha, float *A, int lda, float *X, int incX, float beta, float *Y, int incY)
+{
+	enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
+
+	cblas_sgemv(CblasColMajor, ta, M, N, alpha, A, lda,
+					X, incX, beta, Y, incY);
+}
+
+inline void DGEMV(char *transa, int M, int N, double alpha, double *A, int lda, double *X, int incX, double beta, double *Y, int incY)
+{
+	enum CBLAS_TRANSPOSE ta = (toupper(transa[0]) == 'N')?CblasNoTrans:CblasTrans;
+
+	cblas_dgemv(CblasColMajor, ta, M, N, alpha, A, lda,
+					X, incX, beta, Y, incY);
+}
+
 inline float SASUM(int N, float *X, int incX)
 {
 	return cblas_sasum(N, X, incX);
@@ -236,6 +252,19 @@ inline void DGEMM(char *transa, char *transb, int M, int N, int K,
 			 &beta, C, &ldc);	
 }
 
+
+inline void SGEMV(char *transa, int M, int N, float alpha, float *A, int lda,
+		float *X, int incX, float beta, float *Y, int incY)
+{
+	sgemv_(transa, &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
+}
+
+inline void DGEMV(char *transa, int M, int N, double alpha, double *A, int lda,
+		double *X, int incX, double beta, double *Y, int incY)
+{
+	dgemv_(transa, &M, &N, &alpha, A, &lda, X, &incX, &beta, Y, &incY);
+}
+
 inline float SASUM(int N, float *X, int incX)
 {
 	return sasum_(&N, X, &incX);

+ 4 - 0
examples/common/blas.h

@@ -27,6 +27,10 @@ void SGEMM(char *transa, char *transb, int M, int N, int K, float alpha, float *
 		float *B, int ldb, float beta, float *C, int ldc);
 void DGEMM(char *transa, char *transb, int M, int N, int K, double alpha, double *A, int lda, 
 		double *B, int ldb, double beta, double *C, int ldc);
+void SGEMV(char *transa, int M, int N, float alpha, float *A, int lda,
+		float *X, int incX, float beta, float *Y, int incY);
+void DGEMV(char *transa, int M, int N, double alpha, double *A, int lda,
+		double *X, int incX, double beta, double *Y, int incY);
 float SASUM(int N, float *X, int incX);
 double DASUM(int N, double *X, int incX);
 void SSCAL(int N, float alpha, float *X, int incX);