22 #ifndef RAJA_policy_vector_register_avx2_float_HPP
23 #define RAJA_policy_vector_register_avx2_float_HPP
25 #include "RAJA/config.hpp"
30 #include <immintrin.h>
// Specialization of Register for element type float under the avx2_register
// policy. It derives from internal::expt::RegisterBase instantiated on this
// very class (CRTP).
38 class Register<float, avx2_register>
39 :
public internal::expt::RegisterBase<Register<float, avx2_register>>
// NOTE(review): the next line looks like the tail of an alias declaration
// (presumably base_type, which the copy constructor references) whose first
// line is not visible in this listing -- confirm against the full header.
43 internal::expt::RegisterBase<Register<float, avx2_register>>;
// Policy tag selecting this register implementation.
45 using register_policy = avx2_register;
// Shorthand for this specialization.
46 using self_type = Register<float, avx2_register>;
// Scalar type stored in each lane.
47 using element_type = float;
// Raw intrinsic vector type wrapped by this class.
48 using register_type = __m256;
// Register of 32-bit integers under the same policy.
50 using int_vector_type = Register<int32_t, avx2_register>;
// The wrapped vector value that every member function reads or writes.
54 register_type m_value;
57 __m256i createMask(camp::idx_t N)
const
60 return _mm256_set_epi32(N >= 8 ? -1 : 0, N >= 7 ? -1 : 0, N >= 6 ? -1 : 0,
61 N >= 5 ? -1 : 0, N >= 4 ? -1 : 0, N >= 3 ? -1 : 0,
62 N >= 2 ? -1 : 0, N >= 1 ? -1 : 0);
66 __m256i createStridedOffsets(camp::idx_t stride)
const
69 return _mm256_set_epi32(7 * stride, 6 * stride, 5 * stride, 4 * stride,
70 3 * stride, 2 * stride, stride, 0);
74 __m256i createPermute1(camp::idx_t N)
const
77 return _mm256_set_epi32(N >= 7 ? 6 : 0, N >= 8 ? 7 : 0, N >= 5 ? 4 : 0,
78 N >= 6 ? 5 : 0, N >= 3 ? 2 : 0, N >= 4 ? 3 : 0,
79 N >= 1 ? 0 : 0, N >= 2 ? 1 : 0);
83 __m256i createPermute2(camp::idx_t N)
const
86 return _mm256_set_epi32(N >= 6 ? 5 : 0, N >= 5 ? 4 : 0, N >= 8 ? 7 : 0,
87 N >= 7 ? 6 : 0, N >= 2 ? 1 : 0, N >= 1 ? 0 : 0,
88 N >= 4 ? 3 : 0, N >= 2 ? 2 : 0);
// Number of float lanes in this register type.
92 static constexpr camp::idx_t s_num_elem = 8;
98 Register() : m_value(_mm256_setzero_ps()) {}
// Element-wise constructor. The initializer hands x7..x0 to _mm256_set_ps,
// which takes its arguments from the highest lane down, so x0 ends up in
// lane 0 and x7 in lane 7.
// NOTE(review): the declarations of parameters x1..x7 are not visible in
// this listing.
104 Register(element_type x0,
112 : m_value(_mm256_set_ps(x7, x6, x5, x4, x3, x2, x1, x0))
119 explicit Register(register_type
const& c) : m_value(c) {}
125 Register(self_type
const& c) : base_type(c), m_value(c.m_value) {}
// Copy assignment from another register of the same type.
// NOTE(review): the body is not visible in this listing.
131 self_type& operator=(self_type
const& c)
142 Register(element_type
const& c) : m_value(_mm256_set1_ps(c)) {}
148 constexpr register_type get_register()
const {
return m_value; }
// Fills the register from contiguous memory starting at ptr, using the
// unaligned load intrinsic.
// NOTE(review): the return statement is not visible in this listing; the
// self_type& return type suggests it returns *this -- confirm.
155 self_type& load_packed(element_type
const* ptr)
157 m_value = _mm256_loadu_ps(ptr);
// Fills the first N lanes from contiguous memory starting at ptr via a
// masked load; the mask comes from createMask(N).
// NOTE(review): the return statement is not visible in this listing; the
// self_type& return type suggests it returns *this -- confirm.
167 self_type& load_packed_n(element_type
const* ptr, camp::idx_t N)
169 m_value = _mm256_maskload_ps(ptr, createMask(N));
// Fills the register with a gather: lane k is read from element offset
// k * stride relative to ptr (offsets from createStridedOffsets, scaled by
// sizeof(element_type) to get byte addresses).
// NOTE(review): the return statement is not visible in this listing; the
// self_type& return type suggests it returns *this -- confirm.
178 self_type& load_strided(element_type
const* ptr, camp::idx_t stride)
180 m_value = _mm256_i32gather_ps(ptr, createStridedOffsets(stride),
181 sizeof(element_type));
// Masked variant of the strided gather: only lanes selected by createMask(N)
// are read from memory; a zero vector is passed as the source for the
// remaining lanes.
// NOTE(review): part of the parameter list (the body uses `stride` and `N`)
// and the return statement are not visible in this listing.
191 self_type& load_strided_n(element_type
const* ptr,
195 m_value = _mm256_mask_i32gather_ps(
196 _mm256_setzero_ps(), ptr, createStridedOffsets(stride),
197 _mm256_castsi256_ps(createMask(N)),
sizeof(element_type));
// Writes the whole register to contiguous memory starting at ptr, using the
// unaligned store intrinsic.
// NOTE(review): the return statement is not visible in this listing; the
// self_type const& return type suggests it returns *this -- confirm.
206 self_type
const& store_packed(element_type* ptr)
const
208 _mm256_storeu_ps(ptr, m_value);
// Writes the first N lanes to contiguous memory starting at ptr via a masked
// store; the mask comes from createMask(N).
// NOTE(review): the return statement is not visible in this listing; the
// self_type const& return type suggests it returns *this -- confirm.
217 self_type
const& store_packed_n(element_type* ptr, camp::idx_t N)
const
219 _mm256_maskstore_ps(ptr, createMask(N), m_value);
// Scatters all 8 lanes to memory with a scalar loop: lane i is written to
// ptr[i * stride].
// m_value[i] relies on the compiler's vector-subscript extension.
// NOTE(review): the return statement is not visible in this listing; the
// self_type const& return type suggests it returns *this -- confirm.
228 self_type
const& store_strided(element_type* ptr, camp::idx_t stride)
const
230 for (camp::idx_t i = 0; i < 8; ++i)
232 ptr[i * stride] = m_value[i];
// Scatters the first N lanes to memory with a scalar loop: lane i is written
// to ptr[i * stride].
// NOTE(review): part of the parameter list (the body uses `stride` and `N`)
// and the return statement are not visible in this listing.
// NOTE(review): as visible here the loop bound is N itself with no clamp to
// 8 -- presumably callers guarantee N <= 8; confirm.
242 self_type
const& store_strided_n(element_type* ptr,
246 for (camp::idx_t i = 0; i < N; ++i)
248 ptr[i * stride] = m_value[i];
259 element_type
get(camp::idx_t i)
const {
return m_value[i]; }
// Sets a single lane; judging by the signature, lane i receives `value`.
// NOTE(review): the body is not visible in this listing.
267 self_type& set(element_type value, camp::idx_t i)
// Replicates the scalar `value` into every lane.
// NOTE(review): the return statement is not visible in this listing; the
// self_type& return type suggests it returns *this -- confirm.
276 self_type& broadcast(element_type
const& value)
278 m_value = _mm256_set1_ps(value);
// Overwrites this register's value with that of src.
// NOTE(review): the return statement is not visible in this listing; the
// self_type& return type suggests it returns *this -- confirm.
285 self_type& copy(self_type
const& src)
287 m_value = src.m_value;
294 self_type add(self_type
const& b)
const
296 return self_type(_mm256_add_ps(m_value, b.m_value));
302 self_type subtract(self_type
const& b)
const
304 return self_type(_mm256_sub_ps(m_value, b.m_value));
310 self_type multiply(self_type
const& b)
const
312 return self_type(_mm256_mul_ps(m_value, b.m_value));
318 self_type divide(self_type
const& b)
const
320 return self_type(_mm256_div_ps(m_value, b.m_value));
326 self_type divide_n(self_type
const& b, camp::idx_t N)
const
329 return self_type(_mm256_set_ps(
330 N >= 8 ?
get(7) / b.get(7) : 0, N >= 7 ?
get(6) / b.get(6) : 0,
331 N >= 6 ?
get(5) / b.get(5) : 0, N >= 5 ?
get(4) / b.get(4) : 0,
332 N >= 4 ?
get(3) / b.get(3) : 0, N >= 3 ?
get(2) / b.get(2) : 0,
333 N >= 2 ?
get(1) / b.get(1) : 0, N >= 1 ?
get(0) / b.get(0) : 0));
341 self_type multiply_add(self_type
const& b, self_type
const& c)
const
343 return self_type(_mm256_fmadd_ps(m_value, b.m_value, c.m_value));
349 self_type multiply_subtract(self_type
const& b, self_type
const& c)
const
351 return self_type(_mm256_fmsub_ps(m_value, b.m_value, c.m_value));
360 element_type
sum()
const
363 auto sh1 = _mm256_permute_ps(m_value, 0xB1);
364 auto red1 = _mm256_add_ps(m_value, sh1);
367 auto sh2 = _mm256_permute_ps(red1, 0x4E);
368 auto red2 = _mm256_add_ps(red1, sh2);
370 return red2[0] + red2[4];
378 element_type
max()
const
382 auto sh1 = _mm256_permutevar8x32_ps(m_value, createPermute1(8));
383 auto red1 = _mm256_max_ps(m_value, sh1);
386 auto sh2 = _mm256_permutevar8x32_ps(red1, createPermute2(8));
387 auto red2 = _mm256_max_ps(red1, sh2);
389 return std::max<element_type>(red2[0], red2[4]);
// Horizontal maximum restricted to the first N lanes.
// NOTE(review): the branch conditions that choose between the return
// statements below are not visible in this listing; the per-line notes are
// inferred only from which lanes each statement reads -- confirm against the
// full header.
397 element_type max_n(camp::idx_t N)
const
// Reads lanes 0 and 1 only (presumably the N == 2 case).
410 return std::max<element_type>(m_value[0], m_value[1]);
// First reduction round: each lane compared with its neighbour, with the
// permutation limited to the first N lanes.
414 auto sh1 = _mm256_permutevar8x32_ps(m_value, createPermute1(N));
415 auto red1 = _mm256_max_ps(m_value, sh1);
// Pair result for lanes 0-1 combined with raw lane 2 (presumably N == 3).
419 return std::max<element_type>(red1[0], m_value[2]);
// Pair results for lanes 0-1 and 2-3 combined (presumably N == 4).
423 return std::max<element_type>(red1[0], red1[2]);
// Second reduction round, then the low-half and high-half results combined.
427 auto sh2 = _mm256_permutevar8x32_ps(red1, createPermute2(N));
428 auto red2 = _mm256_max_ps(red1, sh2);
430 return std::max<element_type>(red2[0], red2[4]);
438 self_type vmax(self_type a)
const
440 return self_type(_mm256_max_ps(m_value, a.m_value));
448 element_type
min()
const
452 auto sh1 = _mm256_permutevar8x32_ps(m_value, createPermute1(8));
453 auto red1 = _mm256_min_ps(m_value, sh1);
456 auto sh2 = _mm256_permutevar8x32_ps(red1, createPermute2(8));
457 auto red2 = _mm256_min_ps(red1, sh2);
459 return std::min<element_type>(red2[0], red2[4]);
// Horizontal minimum restricted to the first N lanes.
// NOTE(review): the branch conditions that choose between the return
// statements below are not visible in this listing; the per-line notes are
// inferred only from which lanes each statement reads -- confirm against the
// full header.
467 element_type min_n(camp::idx_t N)
const
// Reads lanes 0 and 1 only (presumably the N == 2 case).
480 return std::min<element_type>(m_value[0], m_value[1]);
// First reduction round: each lane compared with its neighbour, with the
// permutation limited to the first N lanes.
484 auto sh1 = _mm256_permutevar8x32_ps(m_value, createPermute1(N));
485 auto red1 = _mm256_min_ps(m_value, sh1);
// Pair result for lanes 0-1 combined with raw lane 2 (presumably N == 3).
489 return std::min<element_type>(red1[0], m_value[2]);
// Pair results for lanes 0-1 and 2-3 combined (presumably N == 4).
493 return std::min<element_type>(red1[0], red1[2]);
// Second reduction round, then the low-half and high-half results combined.
497 auto sh2 = _mm256_permutevar8x32_ps(red1, createPermute2(N));
498 auto red2 = _mm256_min_ps(red1, sh2);
500 return std::min<element_type>(red2[0], red2[4]);
508 self_type vmin(self_type a)
const
510 return self_type(_mm256_min_ps(m_value, a.m_value));
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155