// NOTE(review): the bare numerals scattered through this text (33, 34, 37, 43, ...) look like
// line numbers from a rendered source listing that were fused into the code during extraction.
// They are not C++ tokens, and gaps in the numbering mark listing lines absent from this view.
//
// Include guard EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H, then the include of
// "../InternalHeaderCheck.h".
//
// product_triangular_matrix_matrix_trmm: dispatch point for the BLAS ?TRMM back end.
// The primary template has an empty body and only inherits from
// product_triangular_matrix_matrix<..., ResStorageOrder, 1, BuiltIn>, so any combination of
// template arguments without a dedicated specialization uses whatever that base provides.
// The EIGEN_BLAS_TRMM_L / EIGEN_BLAS_TRMM_R macros further down emit partial specializations
// (ColMajor result only) that replace this default for specific scalar types.
// NOTE(review): the literal `1` passed to the base lines up with the `resIncr == 1` assertion in
// EIGEN_BLAS_TRMM_SPECIALIZE, so it is presumably the result inner stride -- confirm against the
// declaration of product_triangular_matrix_matrix.
33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H 34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H 37 #include "../InternalHeaderCheck.h" 43 template <
typename Scalar,
typename Index,
int Mode,  // bit mask; the specializations test it against Lower, UnitDiag and ZeroDiag
bool LhsIsTriangular,  // true: left operand is the triangular one; false: right operand
int LhsStorageOrder,  // compared against ColMajor / RowMajor by the specializations
bool ConjugateLhs,
44 int RhsStorageOrder,
bool ConjugateRhs,
int ResStorageOrder>  // the BLAS specializations in this file exist for ColMajor only
45 struct product_triangular_matrix_matrix_trmm
46 : product_triangular_matrix_matrix<Scalar, Index, Mode, LhsIsTriangular, LhsStorageOrder, ConjugateLhs,
47 RhsStorageOrder, ConjugateRhs, ResStorageOrder, 1, BuiltIn> {};
// EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
// Emits a partial specialization of product_triangular_matrix_matrix for the given scalar type
// and triangular side, restricted to a ColMajor result with trailing arguments `1, Specialized`.
// Its run():
//   - marks resIncr as used for debugging only and asserts resIncr == 1;
//   - forwards the remaining arguments to product_triangular_matrix_matrix_trmm<...>::run,
//     whose signature has no resIncr parameter (resIncr is dropped from the forwarded list).
// NOTE(review): the tail of the macro (the last call argument and the closing braces, listing
// lines 63-66) is missing from this view; restore it from the original header before compiling.
//
// The macro is then instantiated for double, dcomplex, float and scomplex, each once with the
// triangular operand on the left (true) and once on the right (false).
50 #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \ 51 template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \ 52 struct product_triangular_matrix_matrix<Scalar, Index, Mode, LhsIsTriangular, LhsStorageOrder, ConjugateLhs, \ 53 RhsStorageOrder, ConjugateRhs, ColMajor, 1, Specialized> { \ 54 static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride, \ 55 const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, \ 56 Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) { \ 57 EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \ 58 eigen_assert(resIncr == 1); \ 59 product_triangular_matrix_matrix_trmm<Scalar, Index, Mode, LhsIsTriangular, LhsStorageOrder, ConjugateLhs, \ 60 RhsStorageOrder, ConjugateRhs, ColMajor>::run(_rows, _cols, _depth, _lhs, \ 61 lhsStride, _rhs, rhsStride, \ 62 res, resStride, alpha, \ 67 EIGEN_BLAS_TRMM_SPECIALIZE(
double,
true)
68 EIGEN_BLAS_TRMM_SPECIALIZE(
double, false)
69 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true)
70 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false)
71 EIGEN_BLAS_TRMM_SPECIALIZE(
float, true)
72 EIGEN_BLAS_TRMM_SPECIALIZE(
float, false)
73 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
74 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
// EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
// Emits the product_triangular_matrix_matrix_trmm specialization for a triangular LEFT operand
// (LhsIsTriangular == true) and a ColMajor result, built on the BLAS routine BLASFUNC.
//   EIGTYPE   - scalar type of both operands and of the result
//   BLASTYPE  - element type that the alpha, a and b pointers are cast to for the BLAS call
//   EIGPREFIX - suffix pasted onto MatrixX to name the dense temporary type (MatrixX##EIGPREFIX)
//   BLASFUNC  - BLAS entry point to call
//
// What run() does, as far as this view shows:
//   1. Returns immediately when any of _rows, _cols, _depth is zero. This also guarantees
//      diagSize >= 1, so the later division by (double)diagSize is safe.
//   2. diagSize = min(_rows, _depth); rows (upper case) or depth (lower case) is clipped to it.
//   3. If rows != depth, BLASFUNC is not used. When nthr == 1 and the excess over diagSize is
//      below half of diagSize, the call goes to the BuiltIn product_triangular_matrix_matrix
//      kernel with a result increment of 1. The other visible code copies the lhs triangular view
//      into a dense temporary (aa_tmp) and runs general_matrix_matrix_product on it with a local
//      gemm_blocking_space.
//   4. Otherwise BLASFUNC is called with side = 'L', m = diagSize, n = cols. transa is 'N' for a
//      ColMajor lhs and 'T' (or 'C' with ConjugateLhs) for a RowMajor lhs, in which case uplo is
//      also flipped. When conjA != 0 or SetDiag == 0, a dense copy a_tmp is involved (conjugated,
//      and/or with its diagonal overwritten by zeros or ones) and lda is its outer stride;
//      otherwise lda is lhsStride.
//   5. The product held in b_tmp is added into res (res_tmp = res_tmp + b_tmp).
// NOTE(review): many listing lines are missing here: the `enum {` opener and closer, the
// declarations of cols, nthr, a, b and a_tmp, the `else` branches, the conditions guarding the
// conjugate() copies and setZero(), and the macro's closing braces. The control flow above is
// read from the visible fragments only; confirm against the original header.
// NOTE(review): diag is initialized to 'N' and unit/zero diagonals are written into a_tmp
// explicitly; whether diag is ever reassigned cannot be seen in this view.
77 #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ 78 template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \ 79 struct product_triangular_matrix_matrix_trmm<EIGTYPE, Index, Mode, true, LhsStorageOrder, ConjugateLhs, \ 80 RhsStorageOrder, ConjugateRhs, ColMajor> { \ 82 IsLower = (Mode & Lower) == Lower, \ 83 SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \ 84 IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \ 85 IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \ 86 LowUp = IsLower ? Lower : Upper, \ 87 conjA = ((LhsStorageOrder == ColMajor) && ConjugateLhs) ? 1 : 0 \ 90 static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \ 91 Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \ 92 level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \ 93 if (_rows == 0 || _cols == 0 || _depth == 0) return; \ 94 Index diagSize = (std::min)(_rows, _depth); \ 95 Index rows = IsLower ? _rows : diagSize; \ 96 Index depth = IsLower ? 
/* This line continues the `depth` ternary, then holds the non-square fallback (rows != depth) and the start of the BLAS argument setup. */
diagSize : _depth; \ 99 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 100 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 103 if (rows != depth) { \ 107 if (((nthr == 1) && (((std::max)(rows, depth) - diagSize) / (double)diagSize < 0.5))) { \ 109 product_triangular_matrix_matrix<EIGTYPE, Index, Mode, true, LhsStorageOrder, ConjugateLhs, RhsStorageOrder, \ 110 ConjugateRhs, ColMajor, 1, BuiltIn>::run(_rows, _cols, _depth, _lhs, \ 111 lhsStride, _rhs, rhsStride, res, \ 112 1, resStride, alpha, blocking); \ 116 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs, rows, depth, OuterStride<>(lhsStride)); \ 117 MatrixLhs aa_tmp = lhsMap.template triangularView<Mode>(); \ 118 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ 119 gemm_blocking_space<ColMajor, EIGTYPE, EIGTYPE, Dynamic, Dynamic, Dynamic> gemm_blocking(_rows, _cols, \ 121 general_matrix_matrix_product<Index, EIGTYPE, LhsStorageOrder, ConjugateLhs, EIGTYPE, RhsStorageOrder, \ 122 ConjugateRhs, ColMajor, 1>::run(rows, cols, depth, aa_tmp.data(), aStride, \ 123 _rhs, rhsStride, res, 1, resStride, alpha, \ 130 char side = 'L', transa, uplo, diag = 'N'; \ 133 BlasIndex m, n, lda, ldb; \ 136 m = convert_index<BlasIndex>(diagSize); \ 137 n = convert_index<BlasIndex>(cols); \ 140 transa = (LhsStorageOrder == RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ 143 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs, depth, cols, OuterStride<>(rhsStride)); \ 144 MatrixX##EIGPREFIX b_tmp; \ 147 b_tmp = rhs.conjugate(); \ 151 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ 154 uplo = IsLower ? 'L' : 'U'; \ 155 if (LhsStorageOrder == RowMajor) uplo = (uplo == 'L') ? 
/* This line continues the uplo flip for a RowMajor lhs, then prepares a_tmp / lda, calls BLASFUNC and accumulates b_tmp into res. */
'U' : 'L'; \ 157 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs, rows, depth, OuterStride<>(lhsStride)); \ 160 if ((conjA != 0) || (SetDiag == 0)) { \ 162 a_tmp = lhs.conjugate(); \ 166 a_tmp.diagonal().setZero(); \ 167 else if (IsUnitDiag) \ 168 a_tmp.diagonal().setOnes(); \ 170 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ 173 lda = convert_index<BlasIndex>(lhsStride); \ 177 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \ 178 &lda, (BLASTYPE*)b, &ldb); \ 181 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res, rows, cols, OuterStride<>(resStride)); \ 182 res_tmp = res_tmp + b_tmp; \ 187 EIGEN_BLAS_TRMM_L(
double,
double, d, dtrmm)
188 EIGEN_BLAS_TRMM_L(dcomplex, MKL_Complex16, cd, ztrmm)
189 EIGEN_BLAS_TRMM_L(
float,
float, f, strmm)
190 EIGEN_BLAS_TRMM_L(scomplex, MKL_Complex8, cf, ctrmm)
// Second set of instantiations: the same four scalar types, but BLASTYPE is the plain real type
// (double / float) even for the complex scalars, and the BLAS symbols carry a trailing underscore.
// NOTE(review): the listing numbering skips 186, 191 and 196, so a preprocessor conditional that
// selects one of the two sets is presumably missing from this view. As shown, both sets would
// define the same specializations twice. Confirm against the original header.
192 EIGEN_BLAS_TRMM_L(
double,
double, d, dtrmm_)
193 EIGEN_BLAS_TRMM_L(dcomplex,
double, cd, ztrmm_)
194 EIGEN_BLAS_TRMM_L(
float,
float, f, strmm_)
195 EIGEN_BLAS_TRMM_L(scomplex,
float, cf, ctrmm_)
// EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
// Emits the product_triangular_matrix_matrix_trmm specialization for a triangular RIGHT operand
// (LhsIsTriangular == false) and a ColMajor result, built on the BLAS routine BLASFUNC.
// The macro parameters mean the same as for EIGEN_BLAS_TRMM_L: EIGTYPE is the scalar type,
// BLASTYPE the pointer element type used in the BLAS call, EIGPREFIX the suffix forming
// MatrixX##EIGPREFIX, and BLASFUNC the BLAS entry point.
//
// What run() does, as far as this view shows:
//   1. Returns immediately when any of _rows, _cols, _depth is zero. This also guarantees
//      diagSize >= 1, so the later division by (double)diagSize is safe.
//   2. diagSize = min(_cols, _depth); rows = _rows; cols (lower case) or depth (upper case) is
//      clipped to diagSize.
//   3. If cols != depth, BLASFUNC is not used. When nthr == 1 and the excess over diagSize is
//      below half of diagSize, the call goes to the BuiltIn product_triangular_matrix_matrix
//      kernel. The other visible code copies the rhs triangular view into a dense temporary
//      (aa_tmp) and runs general_matrix_matrix_product with the untouched lhs and a local
//      gemm_blocking_space.
//   4. Otherwise BLASFUNC is called with side = 'R', m = rows, n = diagSize. transa is 'N' for a
//      ColMajor rhs and 'T' (or 'C' with ConjugateRhs) for a RowMajor rhs, in which case uplo is
//      also flipped. Here the general operand is the lhs, which feeds b_tmp, and the triangular
//      operand is the rhs, which feeds a_tmp. lda is a_tmp's outer stride when
//      conjA != 0 or SetDiag == 0, and rhsStride otherwise.
//   5. The product held in b_tmp is added into res (res_tmp = res_tmp + b_tmp).
// NOTE(review): as in the L variant, many listing lines are missing: the enum opener and closer,
// the declarations of nthr, a, b and a_tmp, most arguments of the BuiltIn call (listing lines
// 232-239), the `else` branches, the conditions guarding conjugate() and setZero(), and the
// macro's closing braces. The flow above is read from the visible fragments only; confirm against
// the original header.
199 #define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \ 200 template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \ 201 struct product_triangular_matrix_matrix_trmm<EIGTYPE, Index, Mode, false, LhsStorageOrder, ConjugateLhs, \ 202 RhsStorageOrder, ConjugateRhs, ColMajor> { \ 204 IsLower = (Mode & Lower) == Lower, \ 205 SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \ 206 IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \ 207 IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \ 208 LowUp = IsLower ? Lower : Upper, \ 209 conjA = ((RhsStorageOrder == ColMajor) && ConjugateRhs) ? 1 : 0 \ 212 static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \ 213 Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \ 214 level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \ 215 if (_rows == 0 || _cols == 0 || _depth == 0) return; \ 216 Index diagSize = (std::min)(_cols, _depth); \ 217 Index rows = _rows; \ 218 Index depth = IsLower ? _depth : diagSize; \ 219 Index cols = IsLower ? 
/* This line continues the `cols` ternary, then holds the non-square fallback (cols != depth) and the start of the BLAS argument setup. */
diagSize : _cols; \ 221 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 222 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 225 if (cols != depth) { \ 228 if ((nthr == 1) && (((std::max)(cols, depth) - diagSize) / (double)diagSize < 0.5)) { \ 230 product_triangular_matrix_matrix<EIGTYPE, Index, Mode, false, LhsStorageOrder, ConjugateLhs, \ 231 RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run(_rows, _cols, \ 240 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs, depth, cols, OuterStride<>(rhsStride)); \ 241 MatrixRhs aa_tmp = rhsMap.template triangularView<Mode>(); \ 242 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \ 243 gemm_blocking_space<ColMajor, EIGTYPE, EIGTYPE, Dynamic, Dynamic, Dynamic> gemm_blocking(_rows, _cols, \ 245 general_matrix_matrix_product<Index, EIGTYPE, LhsStorageOrder, ConjugateLhs, EIGTYPE, RhsStorageOrder, \ 246 ConjugateRhs, ColMajor, 1>::run(rows, cols, depth, _lhs, lhsStride, \ 247 aa_tmp.data(), aStride, res, 1, resStride, \ 248 alpha, gemm_blocking, 0); \ 254 char side = 'R', transa, uplo, diag = 'N'; \ 257 BlasIndex m, n, lda, ldb; \ 260 m = convert_index<BlasIndex>(rows); \ 261 n = convert_index<BlasIndex>(diagSize); \ 264 transa = (RhsStorageOrder == RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ 267 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs, rows, depth, OuterStride<>(lhsStride)); \ 268 MatrixX##EIGPREFIX b_tmp; \ 271 b_tmp = lhs.conjugate(); \ 275 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \ 278 uplo = IsLower ? 'L' : 'U'; \ 279 if (RhsStorageOrder == RowMajor) uplo = (uplo == 'L') ? 
/* This line continues the uplo flip for a RowMajor rhs, then prepares a_tmp / lda, calls BLASFUNC and accumulates b_tmp into res. */
'U' : 'L'; \ 281 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs, depth, cols, OuterStride<>(rhsStride)); \ 284 if ((conjA != 0) || (SetDiag == 0)) { \ 286 a_tmp = rhs.conjugate(); \ 290 a_tmp.diagonal().setZero(); \ 291 else if (IsUnitDiag) \ 292 a_tmp.diagonal().setOnes(); \ 294 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \ 297 lda = convert_index<BlasIndex>(rhsStride); \ 301 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \ 302 &lda, (BLASTYPE*)b, &ldb); \ 305 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res, rows, cols, OuterStride<>(resStride)); \ 306 res_tmp = res_tmp + b_tmp; \ 311 EIGEN_BLAS_TRMM_R(
double,
double, d, dtrmm)
312 EIGEN_BLAS_TRMM_R(dcomplex, MKL_Complex16, cd, ztrmm)
313 EIGEN_BLAS_TRMM_R(
float,
float, f, strmm)
314 EIGEN_BLAS_TRMM_R(scomplex, MKL_Complex8, cf, ctrmm)
// Second set of instantiations: the same four scalar types, but BLASTYPE is the plain real type
// (double / float) even for the complex scalars, and the BLAS symbols carry a trailing underscore.
// NOTE(review): the listing numbering skips 310, 315 and 320, so a preprocessor conditional that
// selects one of the two sets is presumably missing from this view. As shown, both sets would
// define the same specializations twice. Confirm against the original header.
316 EIGEN_BLAS_TRMM_R(
double,
double, d, dtrmm_)
317 EIGEN_BLAS_TRMM_R(dcomplex,
double, cd, ztrmm_)
318 EIGEN_BLAS_TRMM_R(
float,
float, f, strmm_)
319 EIGEN_BLAS_TRMM_R(scomplex,
float, cf, ctrmm_)
325 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H

Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1

EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82