// NOTE(review): this text is an extracted source listing. Original line numbers
// (33, 34, 37, 49, 51, 52, ...) are fused into the code and several source lines are
// absent, so it does not compile as shown. Compare against upstream before editing.
33 #ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H 34 #define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H 37 #include "../InternalHeaderCheck.h" 49 template <
// Primary template of triangular_matrix_vector_product_trmv. It has an empty body
// and only inherits from triangular_matrix_vector_product<..., BuiltIn>, so by
// default it uses the BuiltIn implementation. The EIGEN_BLAS_TRMV_CM /
// EIGEN_BLAS_TRMV_RM macros below specialize it (with ColMajor / RowMajor as the
// last template argument) for the BLAS-backed scalar types.
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
// NOTE(review): the base class below names StorageOrder, and the specializations
// pass a 7th argument (ColMajor/RowMajor), but no StorageOrder parameter is visible
// here. Listing line 50 is missing; it is presumably `int StorageOrder>`. Confirm.
51 struct triangular_matrix_vector_product_trmv
52 : triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, StorageOrder, BuiltIn> {};
// EIGEN_BLAS_TRMV_SPECIALIZE(Scalar): for one scalar type, defines the two
// `Specialized` partial specializations of triangular_matrix_vector_product, one for
// ColMajor and one for RowMajor. Each run() forwards every argument unchanged to
// triangular_matrix_vector_product_trmv<..., ColMajor / RowMajor>::run.
// Both the lhs and the rhs use the same Scalar type in these specializations.
// It is instantiated below for double, float, dcomplex and scomplex.
// NOTE(review): listing lines 61-62 and 69-71 are missing from this extraction.
// They are presumably the closing braces of each struct.
54 #define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \ 55 template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ 56 struct triangular_matrix_vector_product<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, ColMajor, Specialized> { \ 57 static void run(Index rows_, Index cols_, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_, Index rhsIncr, \ 58 Scalar* res_, Index resIncr, Scalar alpha) { \ 59 triangular_matrix_vector_product_trmv<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, ColMajor>::run( \ 60 rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \ 63 template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ 64 struct triangular_matrix_vector_product<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, RowMajor, Specialized> { \ 65 static void run(Index rows_, Index cols_, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_, Index rhsIncr, \ 66 Scalar* res_, Index resIncr, Scalar alpha) { \ 67 triangular_matrix_vector_product_trmv<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, RowMajor>::run( \ 68 rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \ 72 EIGEN_BLAS_TRMV_SPECIALIZE(
double)
73 EIGEN_BLAS_TRMV_SPECIALIZE(
float)
74 EIGEN_BLAS_TRMV_SPECIALIZE(dcomplex)
75 EIGEN_BLAS_TRMV_SPECIALIZE(scomplex)
// EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX):
// specializes triangular_matrix_vector_product_trmv for a column-major triangular lhs
// with scalar type EIGTYPE. The visible run() body works as follows:
//   - It returns immediately when rows_ or cols_ is 0.
//   - When ConjLhs or IsZeroDiag is set, it delegates to the BuiltIn
//     triangular_matrix_vector_product kernel. Presumably this is because those
//     modes cannot be expressed by the ?trmv call used below.
//   - It maps rhs with stride rhsIncr and copies it into the temporary x_tmp,
//     conjugating it on the visible path (presumably only when ConjRhs).
//   - It calls BLASPREFIX##trmv on the leading size x size triangle, where
//     size = min(rows_, cols_). uplo is 'L'/'U' from Mode and diag is 'U' for
//     UnitDiag. ?trmv overwrites the vector x in place.
//   - BLASPREFIX##axpy then adds alpha * x into res_ with stride resIncr.
//   - For a non-square lhs (size < max(rows, cols)), it re-copies rhs, because
//     ?trmv clobbered the temporary. BLASPREFIX##gemv then adds the rectangular
//     remainder into y, which is a pointer into res_, scaled by beta. beta's value
//     is not visible here; it is presumably 1.
// EIGPREFIX picks the Eigen vector typedef (VectorX##EIGPREFIX). BLASPREFIX and
// BLASPOSTFIX build the BLAS symbol name (e.g. d + trmv + _). BLASTYPE is the pointer
// type the BLAS routine is called with. SetDiag and LowUp are not referenced in
// the visible code.
// `&numext::real_ref(alpha)` is reinterpreted as `const BLASTYPE*`. For complex
// EIGTYPE this presumably relies on the real part sitting at the start of the value.
// Verify before changing.
// NOTE(review): many listing lines are absent, for example 81, 87, 94-95, 101-103,
// 105-113, 116-119, 124-126, 140-143, 145, 148-150 and 158-162. As a result:
//   - the declarations of x, x_tmp, a and beta are not visible here;
//   - the assignments of incx and trans are not visible here;
//   - the branch that splits the remainder into the `y = res_ + size * resIncr` case
//     and the `a = lhs_ + size * lda` case is not visible here.
// Confirm against upstream.
78 #define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \ 79 template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ 80 struct triangular_matrix_vector_product_trmv<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, ColMajor> { \ 82 IsLower = (Mode & Lower) == Lower, \ 83 SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \ 84 IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \ 85 IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \ 86 LowUp = IsLower ? Lower : Upper \ 88 static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \ 89 Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) { \ 90 if (rows_ == 0 || cols_ == 0) return; \ 91 if (ConjLhs || IsZeroDiag) { \ 92 triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, ColMajor, BuiltIn>::run( \ 93 rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \ 96 Index size = (std::min)(rows_, cols_); \ 97 Index rows = IsLower ? rows_ : size; \ 98 Index cols = IsLower ? size : cols_; \ 100 typedef VectorX##EIGPREFIX VectorRhs; \ 104 Map<const VectorRhs, 0, InnerStride<> > rhs(rhs_, cols, InnerStride<>(rhsIncr)); \ 107 x_tmp = rhs.conjugate(); \ 114 char trans, uplo, diag; \ 115 BlasIndex m, n, lda, incx, incy; \ 120 n = convert_index<BlasIndex>(size); \ 121 lda = convert_index<BlasIndex>(lhsStride); \ 123 incy = convert_index<BlasIndex>(resIncr); \ 127 uplo = IsLower ? 'L' : 'U'; \ 128 diag = IsUnitDiag ? 
'U' : 'N'; \ 131 BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)lhs_, &lda, (BLASTYPE*)x, &incx); \ 134 BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)x, &incx, \ 135 (BLASTYPE*)res_, &incy); \ 137 if (size < (std::max)(rows, cols)) { \ 139 x_tmp = rhs.conjugate(); \ 144 y = res_ + size * resIncr; \ 146 m = convert_index<BlasIndex>(rows - size); \ 147 n = convert_index<BlasIndex>(size); \ 151 a = lhs_ + size * lda; \ 152 m = convert_index<BlasIndex>(size); \ 153 n = convert_index<BlasIndex>(cols - size); \ 155 BLASPREFIX##gemv##BLASPOSTFIX(&trans, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \ 156 &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), \ 157 (BLASTYPE*)y, &incy); \ 163 EIGEN_BLAS_TRMV_CM(
double,
double, d, d, )
164 EIGEN_BLAS_TRMV_CM(dcomplex, MKL_Complex16, cd, z, )
165 EIGEN_BLAS_TRMV_CM(
float,
float, f, s, )
166 EIGEN_BLAS_TRMV_CM(scomplex, MKL_Complex8, cf, c, )
// NOTE(review): listing line 167 is missing here, and so are 162 and 172.
// The instantiations above use MKL complex types (MKL_Complex16 / MKL_Complex8) with
// an empty symbol postfix. The ones below use plain double/float with a `_` postfix
// (Fortran-style BLAS names), for the same scalar types.
// Presumably a preprocessor guard originally selected exactly one of the two sets.
// As shown, both would define the same specializations twice. Confirm.
168 EIGEN_BLAS_TRMV_CM(
double,
double, d, d, _)
169 EIGEN_BLAS_TRMV_CM(dcomplex,
double, cd, z, _)
170 EIGEN_BLAS_TRMV_CM(
float,
float, f, s, _)
171 EIGEN_BLAS_TRMV_CM(scomplex,
float, cf, c, _)
// EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX):
// row-major counterpart of the column-major BLAS specialization. It specializes
// triangular_matrix_vector_product_trmv<..., RowMajor>.
// A row-major lhs with stride lda is the transpose of a column-major matrix. That is
// why the visible code makes these changes relative to the column-major case:
//   - it swaps uplo (IsLower -> 'U', otherwise 'L');
//   - it uses trans 'T', or 'C' when ConjLhs;
//   - it passes &n, &m to ?gemv in swapped order.
// Because ConjLhs is encoded in trans here, the BuiltIn fallback is guarded only by
// the condition on listing line 188. That condition is not visible; it is presumably
// IsZeroDiag. Confirm.
// Otherwise the visible flow is:
//   - return early when rows_ or cols_ is 0;
//   - copy rhs (stride rhsIncr) into the temporary x_tmp, conjugated on the visible
//     path;
//   - call ?trmv in place on the leading size x size triangle;
//   - call ?axpy to add alpha * x into res_;
//   - for a non-square lhs, re-copy rhs and add the remainder with ?gemv.
// SetDiag and LowUp are not referenced in the visible code.
// NOTE(review): many listing lines are absent, for example 178, 184, 188, 191-192,
// 198-200, 202-203, 205-210, 213-216, 219, 221-222, 237-240, 245-248 and 255-259.
// As a result, the declarations of x, x_tmp, a and beta and the assignment of incx
// are not visible here. The else-branch setup of x/y/a for the
// `m = size, n = cols - size` case is also not visible. Confirm against upstream.
175 #define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \ 176 template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ 177 struct triangular_matrix_vector_product_trmv<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, RowMajor> { \ 179 IsLower = (Mode & Lower) == Lower, \ 180 SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \ 181 IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \ 182 IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \ 183 LowUp = IsLower ? Lower : Upper \ 185 static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \ 186 Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) { \ 187 if (rows_ == 0 || cols_ == 0) return; \ 189 triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, RowMajor, BuiltIn>::run( \ 190 rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \ 193 Index size = (std::min)(rows_, cols_); \ 194 Index rows = IsLower ? rows_ : size; \ 195 Index cols = IsLower ? size : cols_; \ 197 typedef VectorX##EIGPREFIX VectorRhs; \ 201 Map<const VectorRhs, 0, InnerStride<> > rhs(rhs_, cols, InnerStride<>(rhsIncr)); \ 204 x_tmp = rhs.conjugate(); \ 211 char trans, uplo, diag; \ 212 BlasIndex m, n, lda, incx, incy; \ 217 n = convert_index<BlasIndex>(size); \ 218 lda = convert_index<BlasIndex>(lhsStride); \ 220 incy = convert_index<BlasIndex>(resIncr); \ 223 trans = ConjLhs ? 'C' : 'T'; \ 224 uplo = IsLower ? 'U' : 'L'; \ 225 diag = IsUnitDiag ? 
'U' : 'N'; \ 228 BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)lhs_, &lda, (BLASTYPE*)x, &incx); \ 231 BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)x, &incx, \ 232 (BLASTYPE*)res_, &incy); \ 234 if (size < (std::max)(rows, cols)) { \ 236 x_tmp = rhs.conjugate(); \ 241 y = res_ + size * resIncr; \ 242 a = lhs_ + size * lda; \ 243 m = convert_index<BlasIndex>(rows - size); \ 244 n = convert_index<BlasIndex>(size); \ 249 m = convert_index<BlasIndex>(size); \ 250 n = convert_index<BlasIndex>(cols - size); \ 252 BLASPREFIX##gemv##BLASPOSTFIX(&trans, &n, &m, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \ 253 &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), \ 254 (BLASTYPE*)y, &incy); \ 260 EIGEN_BLAS_TRMV_RM(
double,
double, d, d, )
261 EIGEN_BLAS_TRMV_RM(dcomplex, MKL_Complex16, cd, z, )
262 EIGEN_BLAS_TRMV_RM(
float,
float, f, s, )
263 EIGEN_BLAS_TRMV_RM(scomplex, MKL_Complex8, cf, c, )
// NOTE(review): listing line 264 is missing here, and so are 259 and 269.
// The instantiations above use MKL complex types with an empty symbol postfix.
// The ones below use plain double/float with a `_` postfix, for the same scalar
// types. Presumably a preprocessor guard selected exactly one set. As shown, both
// would define the same specializations twice. Confirm.
265 EIGEN_BLAS_TRMV_RM(
double,
double, d, d, _)
266 EIGEN_BLAS_TRMV_RM(dcomplex,
double, cd, z, _)
267 EIGEN_BLAS_TRMV_RM(
float,
float, f, s, _)
268 EIGEN_BLAS_TRMV_RM(scomplex,
float, cf, c, _)
275 #endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:82