33 #ifndef EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H 34 #define EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H 37 #include "../InternalHeaderCheck.h" 52 #define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, BLASFUNC) \ 53 template <typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \ 54 struct general_matrix_matrix_product<Index, EIGTYPE, LhsStorageOrder, ConjugateLhs, EIGTYPE, RhsStorageOrder, \ 55 ConjugateRhs, ColMajor, 1> { \ 56 typedef gebp_traits<EIGTYPE, EIGTYPE> Traits; \ 58 static void run(Index rows, Index cols, Index depth, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \ 59 Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \ 60 level3_blocking<EIGTYPE, EIGTYPE>& , GemmParallelInfo<Index>* ) { \ 62 if (rows == 0 || cols == 0 || depth == 0) return; \ 63 EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \ 64 eigen_assert(resIncr == 1); \ 65 char transa, transb; \ 66 BlasIndex m, n, k, lda, ldb, ldc; \ 67 const EIGTYPE *a, *b; \ 69 MatrixX##EIGPREFIX a_tmp, b_tmp; \ 72 transa = (LhsStorageOrder == RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ 73 transb = (RhsStorageOrder == RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ 76 m = convert_index<BlasIndex>(rows); \ 77 n = convert_index<BlasIndex>(cols); \ 78 k = convert_index<BlasIndex>(depth); \ 81 lda = convert_index<BlasIndex>(lhsStride); \ 82 ldb = convert_index<BlasIndex>(rhsStride); \ 83 ldc = convert_index<BlasIndex>(resStride); \ 86 if ((LhsStorageOrder == ColMajor) && (ConjugateLhs)) { \ 87 Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(lhs_, m, k, OuterStride<>(lhsStride)); \ 88 a_tmp = lhs.conjugate(); \ 90 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ 94 if ((RhsStorageOrder == ColMajor) && (ConjugateRhs)) { \ 95 Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(rhs_, k, n, OuterStride<>(rhsStride)); \ 96 b_tmp = rhs.conjugate(); \ 98 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ 102 BLASFUNC(&transa, &transb, &m, &n, &k, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, \ 103 (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \ 108 GEMM_SPECIALIZATION(
double, d,
double, dgemm)
109 GEMM_SPECIALIZATION(
float, f,
float, sgemm)
110 GEMM_SPECIALIZATION(dcomplex, cd, MKL_Complex16, zgemm)
111 GEMM_SPECIALIZATION(scomplex, cf, MKL_Complex8, cgemm)
113 GEMM_SPECIALIZATION(
double, d,
double, dgemm_)
114 GEMM_SPECIALIZATION(
float, f,
float, sgemm_)
115 GEMM_SPECIALIZATION(dcomplex, cd,
double, zgemm_)
116 GEMM_SPECIALIZATION(scomplex, cf,
float, cgemm_)
121 #if EIGEN_USE_OPENBLAS_BFLOAT16 125 void sbgemm_(
const char* trans_a,
const char* trans_b,
const int* M,
const int* N,
const int* K,
const float* alpha,
126 const Eigen::bfloat16* A,
const int* lda,
const Eigen::bfloat16* B,
const int* ldb,
const float* beta,
127 float* C,
const int* ldc);
130 template <
typename Index,
int LhsStorageOrder,
bool ConjugateLhs,
int RhsStorageOrder,
bool ConjugateRhs>
131 struct general_matrix_matrix_product<
Index,
Eigen::bfloat16, LhsStorageOrder, ConjugateLhs, Eigen::bfloat16,
132 RhsStorageOrder, ConjugateRhs,
ColMajor, 1> {
133 typedef gebp_traits<Eigen::bfloat16, Eigen::bfloat16> Traits;
135 static void run(
Index rows,
Index cols,
Index depth,
const Eigen::bfloat16* lhs_,
Index lhsStride,
136 const Eigen::bfloat16* rhs_,
Index rhsStride, Eigen::bfloat16* res,
Index resIncr,
Index resStride,
137 Eigen::bfloat16 alpha, level3_blocking<Eigen::bfloat16, Eigen::bfloat16>& ,
138 GemmParallelInfo<Index>* ) {
140 if (rows == 0 || cols == 0 || depth == 0)
return;
141 EIGEN_ONLY_USED_FOR_DEBUG(resIncr);
142 eigen_assert(resIncr == 1);
144 BlasIndex m, n, k, lda, ldb, ldc;
145 const Eigen::bfloat16 *a, *b;
147 float falpha =
static_cast<float>(alpha);
148 float fbeta = float(1.0);
150 using MatrixXbf16 = Matrix<Eigen::bfloat16, Dynamic, Dynamic>;
151 MatrixXbf16 a_tmp, b_tmp;
155 transa = (LhsStorageOrder ==
RowMajor) ? ((ConjugateLhs) ?
'C' :
'T') :
'N';
156 transb = (RhsStorageOrder ==
RowMajor) ? ((ConjugateRhs) ?
'C' :
'T') :
'N';
159 m = convert_index<BlasIndex>(rows);
160 n = convert_index<BlasIndex>(cols);
161 k = convert_index<BlasIndex>(depth);
164 lda = convert_index<BlasIndex>(lhsStride);
165 ldb = convert_index<BlasIndex>(rhsStride);
166 ldc = convert_index<BlasIndex>(m);
169 if ((LhsStorageOrder ==
ColMajor) && (ConjugateLhs)) {
170 Map<const MatrixXbf16, 0, OuterStride<> > lhs(lhs_, m, k, OuterStride<>(lhsStride));
171 a_tmp = lhs.conjugate();
173 lda = convert_index<BlasIndex>(a_tmp.outerStride());
178 if ((RhsStorageOrder ==
ColMajor) && (ConjugateRhs)) {
179 Map<const MatrixXbf16, 0, OuterStride<> > rhs(rhs_, k, n, OuterStride<>(rhsStride));
180 b_tmp = rhs.conjugate();
182 ldb = convert_index<BlasIndex>(b_tmp.outerStride());
190 sbgemm_(&transa, &transb, &m, &n, &k, (
const float*)&numext::real_ref(falpha), a, &lda, b, &ldb,
191 (
const float*)&numext::real_ref(fbeta), r_tmp.data(), &ldc);
194 Map<MatrixXbf16, 0, OuterStride<> > result(res, m, n, OuterStride<>(resStride));
195 result = r_tmp.cast<Eigen::bfloat16>();
199 #endif // EIGEN_USE_OPENBLAS_SBGEMM 205 #endif // EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H Definition: Constants.h:318
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
Matrix< float, Dynamic, Dynamic > MatrixXf
Dynamic×Dynamic matrix of type float.
Definition: Matrix.h:478
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82
Definition: Constants.h:320