$darkmode
Eigen  5.0.1-dev
TypeCasting.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_TYPE_CASTING_AVX512_H
11 #define EIGEN_TYPE_CASTING_AVX512_H
12 
13 // IWYU pragma: private
14 #include "../../InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
20 template <>
21 struct type_casting_traits<float, bool> : vectorized_type_casting_traits<float, bool> {};
22 template <>
23 struct type_casting_traits<bool, float> : vectorized_type_casting_traits<bool, float> {};
24 
25 template <>
26 struct type_casting_traits<float, int> : vectorized_type_casting_traits<float, int> {};
27 template <>
28 struct type_casting_traits<int, float> : vectorized_type_casting_traits<int, float> {};
29 
30 template <>
31 struct type_casting_traits<float, double> : vectorized_type_casting_traits<float, double> {};
32 template <>
33 struct type_casting_traits<double, float> : vectorized_type_casting_traits<double, float> {};
34 
35 template <>
36 struct type_casting_traits<double, int> : vectorized_type_casting_traits<double, int> {};
37 template <>
38 struct type_casting_traits<int, double> : vectorized_type_casting_traits<int, double> {};
39 
40 template <>
41 struct type_casting_traits<double, int64_t> : vectorized_type_casting_traits<double, int64_t> {};
42 template <>
43 struct type_casting_traits<int64_t, double> : vectorized_type_casting_traits<int64_t, double> {};
44 
45 template <>
46 struct type_casting_traits<half, float> : vectorized_type_casting_traits<half, float> {};
47 template <>
48 struct type_casting_traits<float, half> : vectorized_type_casting_traits<float, half> {};
49 
50 template <>
51 struct type_casting_traits<bfloat16, float> : vectorized_type_casting_traits<bfloat16, float> {};
52 template <>
53 struct type_casting_traits<float, bfloat16> : vectorized_type_casting_traits<float, bfloat16> {};
54 
55 template <>
56 EIGEN_STRONG_INLINE Packet16b pcast<Packet16f, Packet16b>(const Packet16f& a) {
57  __mmask16 mask = _mm512_cmpneq_ps_mask(a, pzero(a));
58  return _mm512_maskz_cvtepi32_epi8(mask, _mm512_set1_epi32(1));
59 }
60 
61 template <>
62 EIGEN_STRONG_INLINE Packet16f pcast<Packet16b, Packet16f>(const Packet16b& a) {
63  return _mm512_cvtepi32_ps(_mm512_and_si512(_mm512_cvtepi8_epi32(a), _mm512_set1_epi32(1)));
64 }
65 
66 template <>
67 EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
68  return _mm512_cvttps_epi32(a);
69 }
70 
71 template <>
72 EIGEN_STRONG_INLINE Packet8d pcast<Packet16f, Packet8d>(const Packet16f& a) {
73  return _mm512_cvtps_pd(_mm512_castps512_ps256(a));
74 }
75 
76 template <>
77 EIGEN_STRONG_INLINE Packet8d pcast<Packet8f, Packet8d>(const Packet8f& a) {
78  return _mm512_cvtps_pd(a);
79 }
80 
81 template <>
82 EIGEN_STRONG_INLINE Packet8l pcast<Packet8d, Packet8l>(const Packet8d& a) {
83 #if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
84  return _mm512_cvttpd_epi64(a);
85 #else
86  constexpr int kTotalBits = sizeof(double) * CHAR_BIT, kMantissaBits = std::numeric_limits<double>::digits - 1,
87  kExponentBits = kTotalBits - kMantissaBits - 1, kBias = (1 << (kExponentBits - 1)) - 1;
88 
89  const __m512i cst_one = _mm512_set1_epi64(1);
90  const __m512i cst_total_bits = _mm512_set1_epi64(kTotalBits);
91  const __m512i cst_bias = _mm512_set1_epi64(kBias);
92 
93  __m512i a_bits = _mm512_castpd_si512(a);
94  // shift left by 1 to clear the sign bit, and shift right by kMantissaBits + 1 to recover biased exponent
95  __m512i biased_e = _mm512_srli_epi64(_mm512_slli_epi64(a_bits, 1), kMantissaBits + 1);
96  __m512i e = _mm512_sub_epi64(biased_e, cst_bias);
97 
98  // shift to the left by kExponentBits + 1 to clear the sign and exponent bits
99  __m512i shifted_mantissa = _mm512_slli_epi64(a_bits, kExponentBits + 1);
100  // shift to the right by kTotalBits - e to convert the significand to an integer
101  __m512i result_significand = _mm512_srlv_epi64(shifted_mantissa, _mm512_sub_epi64(cst_total_bits, e));
102 
103  // add the implied bit
104  __m512i result_exponent = _mm512_sllv_epi64(cst_one, e);
105  // e <= 0 is interpreted as a large positive shift (2's complement), which also conveniently results in zero
106  __m512i result = _mm512_add_epi64(result_significand, result_exponent);
107  // handle negative arguments
108  __mmask8 sign_mask = _mm512_cmplt_epi64_mask(a_bits, _mm512_setzero_si512());
109  result = _mm512_mask_sub_epi64(result, sign_mask, _mm512_setzero_si512(), result);
110  return result;
111 #endif
112 }
113 
114 template <>
115 EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
116  return _mm512_cvtepi32_ps(a);
117 }
118 
119 template <>
120 EIGEN_STRONG_INLINE Packet8d pcast<Packet16i, Packet8d>(const Packet16i& a) {
121  return _mm512_cvtepi32_pd(_mm512_castsi512_si256(a));
122 }
123 
124 template <>
125 EIGEN_STRONG_INLINE Packet8d pcast<Packet8i, Packet8d>(const Packet8i& a) {
126  return _mm512_cvtepi32_pd(a);
127 }
128 
129 template <>
130 EIGEN_STRONG_INLINE Packet8d pcast<Packet8l, Packet8d>(const Packet8l& a) {
131 #if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
132  return _mm512_cvtepi64_pd(a);
133 #else
134  EIGEN_ALIGN64 int64_t aux[8];
135  pstore(aux, a);
136  return _mm512_set_pd(static_cast<double>(aux[7]), static_cast<double>(aux[6]), static_cast<double>(aux[5]),
137  static_cast<double>(aux[4]), static_cast<double>(aux[3]), static_cast<double>(aux[2]),
138  static_cast<double>(aux[1]), static_cast<double>(aux[0]));
139 #endif
140 }
141 
142 template <>
143 EIGEN_STRONG_INLINE Packet16f pcast<Packet8d, Packet16f>(const Packet8d& a, const Packet8d& b) {
144  return cat256(_mm512_cvtpd_ps(a), _mm512_cvtpd_ps(b));
145 }
146 
147 template <>
148 EIGEN_STRONG_INLINE Packet16i pcast<Packet8d, Packet16i>(const Packet8d& a, const Packet8d& b) {
149  return cat256i(_mm512_cvttpd_epi32(a), _mm512_cvttpd_epi32(b));
150 }
151 
152 template <>
153 EIGEN_STRONG_INLINE Packet8i pcast<Packet8d, Packet8i>(const Packet8d& a) {
154  return _mm512_cvtpd_epi32(a);
155 }
156 template <>
157 EIGEN_STRONG_INLINE Packet8f pcast<Packet8d, Packet8f>(const Packet8d& a) {
158  return _mm512_cvtpd_ps(a);
159 }
160 
161 template <>
162 EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
163  return _mm512_castps_si512(a);
164 }
165 
166 template <>
167 EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(const Packet16i& a) {
168  return _mm512_castsi512_ps(a);
169 }
170 
171 template <>
172 EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet16f>(const Packet16f& a) {
173  return _mm512_castps_pd(a);
174 }
175 
176 template <>
177 EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8l>(const Packet8l& a) {
178  return _mm512_castsi512_pd(a);
179 }
180 
181 template <>
182 EIGEN_STRONG_INLINE Packet8l preinterpret<Packet8l, Packet8d>(const Packet8d& a) {
183  return _mm512_castpd_si512(a);
184 }
185 
186 template <>
187 EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8d>(const Packet8d& a) {
188  return _mm512_castpd_ps(a);
189 }
190 
191 template <>
192 EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const Packet16f& a) {
193  return _mm512_castps512_ps256(a);
194 }
195 
196 template <>
197 EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet16f>(const Packet16f& a) {
198  return _mm512_castps512_ps128(a);
199 }
200 
201 template <>
202 EIGEN_STRONG_INLINE Packet4d preinterpret<Packet4d, Packet8d>(const Packet8d& a) {
203  return _mm512_castpd512_pd256(a);
204 }
205 
206 template <>
207 EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet8d>(const Packet8d& a) {
208  return _mm512_castpd512_pd128(a);
209 }
210 
211 template <>
212 EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8f>(const Packet8f& a) {
213  return _mm512_castps256_ps512(a);
214 }
215 
216 template <>
217 EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet4f>(const Packet4f& a) {
218  return _mm512_castps128_ps512(a);
219 }
220 
221 template <>
222 EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet4d>(const Packet4d& a) {
223  return _mm512_castpd256_pd512(a);
224 }
225 
226 template <>
227 EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet2d>(const Packet2d& a) {
228  return _mm512_castpd128_pd512(a);
229 }
230 
231 template <>
232 EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i, Packet16i>(const Packet16i& a) {
233  return _mm512_castsi512_si256(a);
234 }
235 template <>
236 EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet16i>(const Packet16i& a) {
237  return _mm512_castsi512_si128(a);
238 }
239 
240 #ifndef EIGEN_VECTORIZE_AVX512FP16
241 template <>
242 EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet16h>(const Packet16h& a) {
243  return _mm256_castsi256_si128(a);
244 }
245 
246 template <>
247 EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
248  return half2float(a);
249 }
250 
251 template <>
252 EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
253  return float2half(a);
254 }
255 
256 #endif
257 
258 template <>
259 EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
260  return _mm256_castsi256_si128(a);
261 }
262 
263 template <>
264 EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
265  return Bf16ToF32(a);
266 }
267 
268 template <>
269 EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Packet16f& a) {
270  return F32ToBf16(a);
271 }
272 
273 } // end namespace internal
274 
275 } // end namespace Eigen
276 
277 #endif // EIGEN_TYPE_CASTING_AVX512_H
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1