22 #ifndef RAJA_policy_vector_register_avx2_int32_HPP
23 #define RAJA_policy_vector_register_avx2_int32_HPP
25 #include "RAJA/config.hpp"
30 #include <immintrin.h>
// Register specialization whose int32_t lanes are packed in one AVX2 __m256i.
// NOTE(review): this text was scraped from a Doxygen source listing. Whatever
// precedes original line 39 (e.g. a template<> header), the class's opening
// brace, its access specifiers and the "using base_type =" prefix of original
// line 44 are not present in this view -- TODO restore from the real header.
39 class Register<int32_t, avx2_register>
40 :
public internal::expt::RegisterBase<Register<int32_t, avx2_register>>
// Continuation of the base_type alias (its first line is missing here).
44 internal::expt::RegisterBase<Register<int32_t, avx2_register>>;
46 using register_policy = avx2_register;
47 using self_type = Register<int32_t, avx2_register>;
48 using element_type = int32_t;
// Raw SIMD storage type: a 256-bit integer vector.
49 using register_type = __m256i;
// Integer-vector companion type; for int32_t elements it is this same class.
51 using int_vector_type = Register<int32_t, avx2_register>;
// Packed register value (the only data member visible in this listing).
55 register_type m_value;
58 __m256i createMask(camp::idx_t N)
const
61 return _mm256_set_epi32(N >= 8 ? -1 : 0, N >= 7 ? -1 : 0, N >= 6 ? -1 : 0,
62 N >= 5 ? -1 : 0, N >= 4 ? -1 : 0, N >= 3 ? -1 : 0,
63 N >= 2 ? -1 : 0, N >= 1 ? -1 : 0);
67 __m256i createStridedOffsets(camp::idx_t stride)
const
70 return _mm256_set_epi32(7 * stride, 6 * stride, 5 * stride, 4 * stride,
71 3 * stride, 2 * stride, stride, 0);
75 __m256i createPermute1(camp::idx_t N)
const
78 return _mm256_set_epi32(N >= 7 ? 6 : 0, N >= 8 ? 7 : 0, N >= 5 ? 4 : 0,
79 N >= 6 ? 5 : 0, N >= 3 ? 2 : 0, N >= 4 ? 3 : 0,
80 N >= 1 ? 0 : 0, N >= 2 ? 1 : 0);
84 __m256i createPermute2(camp::idx_t N)
const
87 return _mm256_set_epi32(N >= 6 ? 5 : 0, N >= 5 ? 4 : 0, N >= 8 ? 7 : 0,
88 N >= 7 ? 6 : 0, N >= 2 ? 1 : 0, N >= 1 ? 0 : 0,
89 N >= 4 ? 3 : 0, N >= 2 ? 2 : 0);
// Lane count: one 256-bit register holds eight int32_t elements.
93 static constexpr camp::idx_t s_num_elem = 8;
99 Register() : m_value(_mm256_setzero_si256()) {}
105 Register(element_type x0,
113 : m_value(_mm256_set_epi32(x7, x6, x5, x4, x3, x2, x1, x0))
120 explicit Register(register_type
const& c) : m_value(c) {}
126 Register(self_type
const& c) : base_type(c), m_value(c.m_value) {}
132 self_type& operator=(self_type
const& c)
143 Register(element_type
const& c) : m_value(_mm256_set1_epi32(c)) {}
149 constexpr register_type get_register()
const {
return m_value; }
156 self_type& load_packed(element_type
const* ptr)
158 m_value = _mm256_loadu_si256((__m256i
const*)ptr);
168 self_type& load_packed_n(element_type
const* ptr, camp::idx_t N)
170 m_value = _mm256_maskload_epi32(ptr, createMask(N));
179 self_type& load_strided(element_type
const* ptr, camp::idx_t stride)
181 m_value = _mm256_i32gather_epi32(ptr, createStridedOffsets(stride),
182 sizeof(element_type));
192 self_type& load_strided_n(element_type
const* ptr,
196 m_value = _mm256_mask_i32gather_epi32(_mm256_setzero_si256(), ptr,
197 createStridedOffsets(stride),
198 createMask(N),
sizeof(element_type));
207 self_type
const& store_packed(element_type* ptr)
const
209 _mm256_storeu_si256(
reinterpret_cast<__m256i*
>(ptr), m_value);
218 self_type
const& store_packed_n(element_type* ptr, camp::idx_t N)
const
220 _mm256_maskstore_epi32(ptr, createMask(N), m_value);
229 self_type
const& store_strided(element_type* ptr, camp::idx_t stride)
const
231 for (camp::idx_t i = 0; i < 8; ++i)
233 ptr[i * stride] =
get(i);
243 self_type
const& store_strided_n(element_type* ptr,
247 for (camp::idx_t i = 0; i < N; ++i)
249 ptr[i * stride] =
get(i);
260 element_type
get(camp::idx_t i)
const
266 return _mm256_extract_epi32(m_value, 0);
268 return _mm256_extract_epi32(m_value, 1);
270 return _mm256_extract_epi32(m_value, 2);
272 return _mm256_extract_epi32(m_value, 3);
274 return _mm256_extract_epi32(m_value, 4);
276 return _mm256_extract_epi32(m_value, 5);
278 return _mm256_extract_epi32(m_value, 6);
280 return _mm256_extract_epi32(m_value, 7);
291 self_type& set(element_type value, camp::idx_t i)
297 m_value = _mm256_insert_epi32(m_value, value, 0);
300 m_value = _mm256_insert_epi32(m_value, value, 1);
303 m_value = _mm256_insert_epi32(m_value, value, 2);
306 m_value = _mm256_insert_epi32(m_value, value, 3);
309 m_value = _mm256_insert_epi32(m_value, value, 4);
312 m_value = _mm256_insert_epi32(m_value, value, 5);
315 m_value = _mm256_insert_epi32(m_value, value, 6);
318 m_value = _mm256_insert_epi32(m_value, value, 7);
328 self_type& broadcast(element_type
const& value)
330 m_value = _mm256_set1_epi32(value);
337 self_type& copy(self_type
const& src)
339 m_value = src.m_value;
346 self_type add(self_type
const& b)
const
348 return self_type(_mm256_add_epi32(m_value, b.m_value));
354 self_type subtract(self_type
const& b)
const
356 return self_type(_mm256_sub_epi32(m_value, b.m_value));
362 self_type multiply(self_type
const& b)
const
370 auto prod_even = _mm256_mul_epi32(m_value, b.m_value);
373 auto sh_a = _mm256_castps_si256(
374 _mm256_permute_ps(_mm256_castsi256_ps(m_value), 0xB1));
376 auto sh_b = _mm256_castps_si256(
377 _mm256_permute_ps(_mm256_castsi256_ps(b.m_value), 0xB1));
380 auto prod_odd = _mm256_mul_epi32(sh_a, sh_b);
383 auto sh_odd = _mm256_castps_si256(
384 _mm256_permute_ps(_mm256_castsi256_ps(prod_odd), 0xB1));
386 return self_type(_mm256_blend_epi32(prod_even, sh_odd, 0xAA));
392 self_type divide(self_type
const& b)
const
395 return self_type(_mm256_set_epi32(
get(7) / b.get(7),
get(6) / b.get(6),
396 get(5) / b.get(5),
get(4) / b.get(4),
397 get(3) / b.get(3),
get(2) / b.get(2),
398 get(1) / b.get(1),
get(0) / b.get(0)));
404 self_type divide_n(self_type
const& b, camp::idx_t N)
const
407 return self_type(_mm256_set_epi32(
408 N >= 8 ?
get(7) / b.get(7) : 0, N >= 7 ?
get(6) / b.get(6) : 0,
409 N >= 6 ?
get(5) / b.get(5) : 0, N >= 5 ?
get(4) / b.get(4) : 0,
410 N >= 4 ?
get(3) / b.get(3) : 0, N >= 3 ?
get(2) / b.get(2) : 0,
411 N >= 2 ?
get(1) / b.get(1) : 0, N >= 1 ?
get(0) / b.get(0) : 0));
419 element_type
sum()
const
422 auto sh1 = _mm256_castps_si256(
423 _mm256_permute_ps(_mm256_castsi256_ps(m_value), 0xB1));
424 auto red1 = _mm256_add_epi32(m_value, sh1);
429 _mm256_castps_si256(_mm256_permute_ps(_mm256_castsi256_ps(red1), 0x4E));
430 auto red2 = _mm256_add_epi32(red1, sh2);
432 return _mm256_extract_epi32(red2, 0) + _mm256_extract_epi32(red2, 4);
440 element_type
max()
const
444 auto sh1 = _mm256_permutevar8x32_epi32(m_value, createPermute1(8));
445 auto red1 = _mm256_max_epi32(m_value, sh1);
448 auto sh2 = _mm256_permutevar8x32_epi32(red1, createPermute2(8));
449 auto red2 = _mm256_max_epi32(red1, sh2);
451 return std::max<element_type>(_mm256_extract_epi32(red2, 0),
452 _mm256_extract_epi32(red2, 4));
// Maximum over the first N lanes.
// NOTE(review): this Doxygen-scraped listing omits original lines 461-473,
// 475-477, 480-482, 484-486, 489-491 and 494. Those lines hold the braces
// and the conditions guarding each early return, including the result for N
// outside 1..8. Restore them from the real header before changing this logic.
460 element_type max_n(camp::idx_t N)
const
// Presumably the N == 2 case: max of lanes 0 and 1 -- TODO confirm guard.
474 return std::max<element_type>(
get(0),
get(1));
// First step: lane k is compared with its neighbour (k^1). createPermute1
// substitutes lane 0 for any neighbour index >= N, so inactive lanes never
// contribute.
478 auto sh1 = _mm256_permutevar8x32_epi32(m_value, createPermute1(N));
479 auto red1 = _mm256_max_epi32(m_value, sh1);
// Presumably N == 3: red1[0] covers lanes 0 and 1; lane 2 is read directly.
483 return std::max<element_type>(_mm256_extract_epi32(red1, 0),
get(2));
// Presumably N == 4: red1[0] covers lanes 0-1 and red1[2] covers lanes 2-3.
487 return std::max<element_type>(_mm256_extract_epi32(red1, 0),
488 _mm256_extract_epi32(red1, 2));
// General case (presumably N >= 5): second step compares neighbouring lane
// pairs within each 128-bit half (createPermute2); lanes 0 and 4 then hold
// the per-half maxima.
492 auto sh2 = _mm256_permutevar8x32_epi32(red1, createPermute2(N));
493 auto red2 = _mm256_max_epi32(red1, sh2);
495 return std::max<element_type>(_mm256_extract_epi32(red2, 0),
496 _mm256_extract_epi32(red2, 4));
504 self_type vmax(self_type a)
const
506 return self_type(_mm256_max_epi32(m_value, a.m_value));
514 element_type
min()
const
518 auto sh1 = _mm256_permutevar8x32_epi32(m_value, createPermute1(8));
519 auto red1 = _mm256_min_epi32(m_value, sh1);
523 auto sh2 = _mm256_permutevar8x32_epi32(red1, createPermute2(8));
524 auto red2 = _mm256_min_epi32(red1, sh2);
526 return std::min<element_type>(_mm256_extract_epi32(red2, 0),
527 _mm256_extract_epi32(red2, 4));
// Minimum over the first N lanes.
// NOTE(review): this Doxygen-scraped listing omits original lines 536-548,
// 550-552, 555-557, 559-561, 564-566 and 569. Those lines hold the braces
// and the conditions guarding each early return, including the result for N
// outside 1..8. Restore them from the real header before changing this logic.
535 element_type min_n(camp::idx_t N)
const
// Presumably the N == 2 case: min of lanes 0 and 1 -- TODO confirm guard.
549 return std::min<element_type>(
get(0),
get(1));
// First step: lane k is compared with its neighbour (k^1). createPermute1
// substitutes lane 0 for any neighbour index >= N, so inactive lanes never
// contribute.
553 auto sh1 = _mm256_permutevar8x32_epi32(m_value, createPermute1(N));
554 auto red1 = _mm256_min_epi32(m_value, sh1);
// Presumably N == 3: red1[0] covers lanes 0 and 1; lane 2 is read directly.
558 return std::min<element_type>(_mm256_extract_epi32(red1, 0),
get(2));
// Presumably N == 4: red1[0] covers lanes 0-1 and red1[2] covers lanes 2-3.
562 return std::min<element_type>(_mm256_extract_epi32(red1, 0),
563 _mm256_extract_epi32(red1, 2));
// General case (presumably N >= 5): second step compares neighbouring lane
// pairs within each 128-bit half (createPermute2); lanes 0 and 4 then hold
// the per-half minima.
567 auto sh2 = _mm256_permutevar8x32_epi32(red1, createPermute2(N));
568 auto red2 = _mm256_min_epi32(red1, sh2);
570 return std::min<element_type>(_mm256_extract_epi32(red2, 0),
571 _mm256_extract_epi32(red2, 4));
579 self_type vmin(self_type a)
const
581 return self_type(_mm256_min_epi32(m_value, a.m_value));
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155