22 #ifndef RAJA_policy_vector_register_avx_int32_HPP
23 #define RAJA_policy_vector_register_avx_int32_HPP
25 #include "RAJA/config.hpp"
30 #include <immintrin.h>
// Specialization of Register for int32_t elements under the avx_register
// policy. The lanes are stored in a single __m256i (see register_type and
// m_value below); the class derives from RegisterBase via CRTP.
38 class Register<int32_t, avx_register>
39 :
public internal::expt::RegisterBase<Register<int32_t, avx_register>>
// NOTE(review): the next line looks like the tail of an alias declaration for
// the base class (the constructors refer to `base_type`); the start of that
// declaration is not visible in this view -- confirm against the full file.
43 internal::expt::RegisterBase<Register<int32_t, avx_register>>;
// Policy tag, the register class itself, the per-lane element type and the
// underlying intrinsic vector type.
45 using register_policy = avx_register;
46 using self_type = Register<int32_t, avx_register>;
47 using element_type = int32_t;
48 using register_type = __m256i;
// The integer vector type associated with this register is the register
// type itself.
50 using int_vector_type = Register<int32_t, avx_register>;
// Raw storage for the lanes.
54 register_type m_value;
57 __m256i createMask(camp::idx_t N)
const
60 return _mm256_set_epi32(N >= 8 ? -1 : 0, N >= 7 ? -1 : 0, N >= 6 ? -1 : 0,
61 N >= 5 ? -1 : 0, N >= 4 ? -1 : 0, N >= 3 ? -1 : 0,
62 N >= 2 ? -1 : 0, N >= 1 ? -1 : 0);
66 __m256i createStridedOffsets(camp::idx_t stride)
const
69 return _mm256_set_epi32(7 * stride, 6 * stride, 5 * stride, 4 * stride,
70 3 * stride, 2 * stride, stride, 0);
74 __m256i createPermute1(camp::idx_t N)
const
77 return _mm256_set_epi32(N >= 7 ? 6 : 0, N >= 8 ? 7 : 0, N >= 5 ? 4 : 0,
78 N >= 6 ? 5 : 0, N >= 3 ? 2 : 0, N >= 4 ? 3 : 0,
79 N >= 1 ? 0 : 0, N >= 2 ? 1 : 0);
83 __m256i createPermute2(camp::idx_t N)
const
86 return _mm256_set_epi32(N >= 6 ? 5 : 0, N >= 5 ? 4 : 0, N >= 8 ? 7 : 0,
87 N >= 7 ? 6 : 0, N >= 2 ? 1 : 0, N >= 1 ? 0 : 0,
88 N >= 4 ? 3 : 0, N >= 2 ? 2 : 0);
// Number of int32_t lanes in one register.
92 static constexpr camp::idx_t s_num_elem = 8;
98 Register() : base_type(), m_value(_mm256_setzero_si256()) {}
104 explicit Register(register_type
const& c) : base_type(), m_value(c) {}
// Constructs a register from individual lane values.
// NOTE(review): only the first parameter is visible in this view; the
// initializer below also uses x1..x7, which are presumably the remaining
// parameters -- confirm against the full file.
110 Register(element_type x0,
// _mm256_set_epi32 takes the highest lane first, so x0 ends up in lane 0 and
// x7 in lane 7.
118 : m_value(_mm256_set_epi32(x7, x6, x5, x4, x3, x2, x1, x0))
125 Register(self_type
const& c) : base_type(), m_value(c.m_value) {}
// Copy assignment from another register of the same type.
// NOTE(review): the body is not visible in this view.
131 self_type& operator=(self_type
const& c)
142 Register(element_type
const& c) : m_value(_mm256_set1_epi32(c)) {}
149 self_type& load_packed(element_type
const* ptr)
151 m_value = _mm256_loadu_si256((__m256i
const*)ptr);
// Partial load from contiguous memory for N elements.
161 self_type& load_packed_n(element_type
const* ptr, camp::idx_t N)
// Clear every lane first.
163 m_value = _mm256_setzero_si256();
// NOTE(review): the loop body is not visible in this view; presumably it
// copies ptr[i] into lane i -- confirm against the full file.
164 for (camp::idx_t i = 0; i < N; ++i)
176 self_type& load_strided(element_type
const* ptr, camp::idx_t stride)
178 for (camp::idx_t i = 0; i < 8; ++i)
180 set(ptr[i * stride], i);
// Partial strided load: lane i receives ptr[i * stride] for i < N.
// NOTE(review): the rest of the parameter list (stride and N are used
// below) is not visible in this view.
191 self_type& load_strided_n(element_type
const* ptr,
// Clear every lane first so lanes at index N and above read as zero.
195 m_value = _mm256_setzero_si256();
196 for (camp::idx_t i = 0; i < N; ++i)
198 set(ptr[i * stride], i);
208 self_type
const& store_packed(element_type* ptr)
const
210 _mm256_storeu_si256(
reinterpret_cast<__m256i*
>(ptr), m_value);
219 self_type
const& store_packed_n(element_type* ptr, camp::idx_t N)
const
221 _mm256_maskstore_ps(
reinterpret_cast<float*
>(ptr), createMask(N),
222 reinterpret_cast<__m256
>(m_value));
231 self_type
const& store_strided(element_type* ptr, camp::idx_t stride)
const
233 for (camp::idx_t i = 0; i < 8; ++i)
235 ptr[i * stride] =
get(i);
// Partial strided store: lane i is written to ptr[i * stride] for i < N.
// NOTE(review): the rest of the parameter list (stride and N are used
// below) is not visible in this view.
245 self_type
const& store_strided_n(element_type* ptr,
249 for (camp::idx_t i = 0; i < N; ++i)
251 ptr[i * stride] =
get(i);
// Returns the value held in one lane of the register.
// NOTE(review): the dispatch on i (presumably one case per lane) is not
// visible in this view. Each return below extracts one fixed lane, 0
// through 7, using a literal lane index.
262 element_type
get(camp::idx_t i)
const
268 return _mm256_extract_epi32(m_value, 0);
270 return _mm256_extract_epi32(m_value, 1);
272 return _mm256_extract_epi32(m_value, 2);
274 return _mm256_extract_epi32(m_value, 3);
276 return _mm256_extract_epi32(m_value, 4);
278 return _mm256_extract_epi32(m_value, 5);
280 return _mm256_extract_epi32(m_value, 6);
282 return _mm256_extract_epi32(m_value, 7);
// Writes `value` into one lane of the register, leaving the others intact.
// NOTE(review): the dispatch on i (presumably one case per lane) is not
// visible in this view. Each assignment below replaces one fixed lane, 0
// through 7, using a literal lane index.
293 self_type& set(element_type value, camp::idx_t i)
299 m_value = _mm256_insert_epi32(m_value, value, 0);
302 m_value = _mm256_insert_epi32(m_value, value, 1);
305 m_value = _mm256_insert_epi32(m_value, value, 2);
308 m_value = _mm256_insert_epi32(m_value, value, 3);
311 m_value = _mm256_insert_epi32(m_value, value, 4);
314 m_value = _mm256_insert_epi32(m_value, value, 5);
317 m_value = _mm256_insert_epi32(m_value, value, 6);
320 m_value = _mm256_insert_epi32(m_value, value, 7);
330 self_type& broadcast(element_type
const& value)
332 m_value = _mm256_set1_epi32(value);
339 self_type& copy(self_type
const& src)
341 m_value = src.m_value;
348 self_type add(self_type
const& b)
const
353 auto low_a = _mm256_castsi256_si128(m_value);
354 auto low_b = _mm256_castsi256_si128(b.m_value);
355 auto res_low = _mm256_castsi128_si256(_mm_add_epi32(low_a, low_b));
358 auto hi_a = _mm256_extractf128_si256(m_value, 1);
359 auto hi_b = _mm256_extractf128_si256(b.m_value, 1);
360 auto res_hi = _mm_add_epi32(hi_a, hi_b);
363 return self_type(_mm256_insertf128_si256(res_low, res_hi, 1));
369 self_type subtract(self_type
const& b)
const
374 auto low_a = _mm256_castsi256_si128(m_value);
375 auto low_b = _mm256_castsi256_si128(b.m_value);
376 auto res_low = _mm256_castsi128_si256(_mm_sub_epi32(low_a, low_b));
379 auto hi_a = _mm256_extractf128_si256(m_value, 1);
380 auto hi_b = _mm256_extractf128_si256(b.m_value, 1);
381 auto res_hi = _mm_sub_epi32(hi_a, hi_b);
384 return self_type(_mm256_insertf128_si256(res_low, res_hi, 1));
390 self_type multiply(self_type
const& b)
const
396 auto low_a = _mm256_castsi256_si128(m_value);
397 auto low_b = _mm256_castsi256_si128(b.m_value);
399 auto res_low_even = _mm_mul_epi32(low_a, low_b);
402 auto low_a_sh = _mm_shuffle_epi32(low_a, 0xB1);
403 auto low_b_sh = _mm_shuffle_epi32(low_b, 0xB1);
404 auto res_low_odd = _mm_mul_epi32(low_a_sh, low_b_sh);
408 res_low_odd = _mm_shuffle_epi32(res_low_odd, 0xB1);
409 auto res_low = _mm256_castsi128_si256(_mm_castps_si128(_mm_blend_ps(
410 _mm_castsi128_ps(res_low_odd), _mm_castsi128_ps(res_low_even), 0x05)));
414 auto hi_a = _mm256_extractf128_si256(m_value, 1);
415 auto hi_b = _mm256_extractf128_si256(b.m_value, 1);
417 auto res_hi_even = _mm_mul_epi32(hi_a, hi_b);
420 auto hi_a_sh = _mm_shuffle_epi32(hi_a, 0xB1);
421 auto hi_b_sh = _mm_shuffle_epi32(hi_b, 0xB1);
422 auto res_hi_odd = _mm_mul_epi32(hi_a_sh, hi_b_sh);
426 res_hi_odd = _mm_shuffle_epi32(res_hi_odd, 0xB1);
427 auto res_hi = _mm_castps_si128(_mm_blend_ps(
428 _mm_castsi128_ps(res_hi_odd), _mm_castsi128_ps(res_hi_even), 0x05));
431 return self_type(_mm256_insertf128_si256(res_low, res_hi, 1));
437 self_type divide(self_type
const& b)
const
440 return self_type(_mm256_set_epi32(
get(7) / b.get(7),
get(6) / b.get(6),
441 get(5) / b.get(5),
get(4) / b.get(4),
442 get(3) / b.get(3),
get(2) / b.get(2),
443 get(1) / b.get(1),
get(0) / b.get(0)));
449 self_type divide_n(self_type
const& b, camp::idx_t N)
const
452 return self_type(_mm256_set_epi32(
453 N >= 8 ?
get(7) / b.get(7) : 0, N >= 7 ?
get(6) / b.get(6) : 0,
454 N >= 6 ?
get(5) / b.get(5) : 0, N >= 5 ?
get(4) / b.get(4) : 0,
455 N >= 4 ?
get(3) / b.get(3) : 0, N >= 3 ?
get(2) / b.get(2) : 0,
456 N >= 2 ?
get(1) / b.get(1) : 0, N >= 1 ?
get(0) / b.get(0) : 0));
464 element_type
sum()
const
467 auto low = _mm256_castsi256_si128(m_value);
469 auto low_sh1 = _mm_shuffle_epi32(low, 0xB1);
470 auto low_red1 = _mm_add_epi32(low, low_sh1);
472 auto low_sh2 = _mm_shuffle_epi32(low_red1, 0x1B);
473 auto low_red2 = _mm_add_epi32(low_red1, low_sh2);
477 auto hi = _mm256_extractf128_si256(m_value, 1);
479 auto hi_sh1 = _mm_shuffle_epi32(hi, 0xB1);
480 auto hi_red1 = _mm_add_epi32(hi, hi_sh1);
482 auto hi_sh2 = _mm_shuffle_epi32(hi_red1, 0x1B);
483 auto hi_red2 = _mm_add_epi32(hi_red1, hi_sh2);
487 auto hi_low = _mm_add_epi32(hi_red2, low_red2);
488 return _mm_extract_epi32(hi_low, 0);
496 element_type
max()
const
504 auto low = _mm256_castsi256_si128(m_value);
506 auto low_sh1 = _mm_shuffle_epi32(low, 0xB1);
507 auto low_red1 = _mm_max_epi32(low, low_sh1);
509 auto low_sh2 = _mm_shuffle_epi32(low_red1, 0x1B);
512 auto low_red2 = _mm_max_epi32(low_red1, low_sh2);
516 auto hi = _mm256_extractf128_si256(m_value, 1);
519 auto hi_sh1 = _mm_shuffle_epi32(hi, 0xB1);
520 auto hi_red1 = _mm_max_epi32(hi, hi_sh1);
522 auto hi_sh2 = _mm_shuffle_epi32(hi_red1, 0x1B);
523 auto hi_red2 = _mm_max_epi32(hi_red1, hi_sh2);
527 auto hi_low = _mm_max_epi32(hi_red2, low_red2);
528 return _mm_extract_epi32(hi_low, 0);
// Maximum restricted by N (presumably over the first N lanes).
// NOTE(review): the conditions selecting between the early returns below are
// not visible in this view; the local names (red_5, red_6, red_7) suggest
// one exit per value of N -- confirm against the full file.
536 element_type max_n(camp::idx_t N)
const
// Exit that returns lane 0 unchanged.
549 return _mm256_extract_epi32(m_value, 0);
// Lower 128-bit half (lanes 0-3).
553 auto low = _mm256_castsi256_si128(m_value);
// 0xB1 swaps neighbouring lanes, so lane 0 of low_red1 is max(lane 0, lane 1).
555 auto low_sh1 = _mm_shuffle_epi32(low, 0xB1);
556 auto low_red1 = _mm_max_epi32(low, low_sh1);
560 return _mm_extract_epi32(low_red1, 0);
// 0x2 moves lane 2 into lane 0, so lane 0 of low_red1a covers lanes 0-2.
566 auto low_sh1a = _mm_shuffle_epi32(low, 0x2);
567 auto low_red1a = _mm_max_epi32(low_red1, low_sh1a);
568 return _mm_extract_epi32(low_red1a, 0);
// 0x1B reverses the lane order, so low_red2 holds the maximum of lanes 0-3.
571 auto low_sh2 = _mm_shuffle_epi32(low_red1, 0x1B);
574 auto low_red2 = _mm_max_epi32(low_red1, low_sh2);
578 return _mm_extract_epi32(low_red2, 0);
// Upper 128-bit half (lanes 4-7).
582 auto hi = _mm256_extractf128_si256(m_value, 1);
// Lane 0 of red_5 combines lanes 0-3 with lane 4.
586 auto red_5 = _mm_max_epi32(low_red2, hi);
587 return _mm_extract_epi32(red_5, 0);
// Lane 0 of hi_red1 is max(lane 4, lane 5).
590 auto hi_sh1 = _mm_shuffle_epi32(hi, 0xB1);
591 auto hi_red1 = _mm_max_epi32(hi, hi_sh1);
// Lane 0 of red_6 combines lanes 0-3 with lanes 4-5.
595 auto red_6 = _mm_max_epi32(low_red2, hi_red1);
596 return _mm_extract_epi32(red_6, 0);
// 0x2 moves lane 6 into lane 0 of hi_sh7; lane 0 of red_7 covers lanes 0-6.
601 auto hi_sh7 = _mm_shuffle_epi32(hi, 0x2);
602 auto hi_red_6 = _mm_max_epi32(hi_sh7, hi_red1);
603 auto red_7 = _mm_max_epi32(low_red2, hi_red_6);
604 return _mm_extract_epi32(red_7, 0);
// Full reduction of the upper half, then combine with the lower half.
607 auto hi_sh2 = _mm_shuffle_epi32(hi_red1, 0x1B);
608 auto hi_red2 = _mm_max_epi32(hi_red1, hi_sh2);
612 auto hi_low = _mm_max_epi32(hi_red2, low_red2);
613 return _mm_extract_epi32(hi_low, 0);
621 self_type vmax(self_type b)
const
626 auto low_a = _mm256_castsi256_si128(m_value);
627 auto low_b = _mm256_castsi256_si128(b.m_value);
628 auto res_low = _mm256_castsi128_si256(_mm_max_epi32(low_a, low_b));
631 auto hi_a = _mm256_extractf128_si256(m_value, 1);
632 auto hi_b = _mm256_extractf128_si256(b.m_value, 1);
633 auto res_hi = _mm_max_epi32(hi_a, hi_b);
636 return self_type(_mm256_insertf128_si256(res_low, res_hi, 1));
644 element_type
min()
const
651 auto low = _mm256_castsi256_si128(m_value);
653 auto low_sh1 = _mm_shuffle_epi32(low, 0xB1);
654 auto low_red1 = _mm_min_epi32(low, low_sh1);
656 auto low_sh2 = _mm_shuffle_epi32(low_red1, 0x1B);
659 auto low_red2 = _mm_min_epi32(low_red1, low_sh2);
663 auto hi = _mm256_extractf128_si256(m_value, 1);
665 auto hi_sh1 = _mm_shuffle_epi32(hi, 0xB1);
666 auto hi_red1 = _mm_min_epi32(hi, hi_sh1);
669 auto hi_sh2 = _mm_shuffle_epi32(hi_red1, 0x1B);
670 auto hi_red2 = _mm_min_epi32(hi_red1, hi_sh2);
674 auto hi_low = _mm_min_epi32(hi_red2, low_red2);
675 return _mm_extract_epi32(hi_low, 0);
// Minimum restricted by N (presumably over the first N lanes).
// NOTE(review): the conditions selecting between the early returns below are
// not visible in this view; the local names (red_5, red_6, red_7) suggest
// one exit per value of N -- confirm against the full file.
683 element_type min_n(camp::idx_t N)
const
// Exit that returns lane 0 unchanged.
695 return _mm256_extract_epi32(m_value, 0);
// Lower 128-bit half (lanes 0-3).
699 auto low = _mm256_castsi256_si128(m_value);
// 0xB1 swaps neighbouring lanes, so lane 0 of low_red1 is min(lane 0, lane 1).
701 auto low_sh1 = _mm_shuffle_epi32(low, 0xB1);
702 auto low_red1 = _mm_min_epi32(low, low_sh1);
706 return _mm_extract_epi32(low_red1, 0);
// 0x2 moves lane 2 into lane 0, so lane 0 of low_red1a covers lanes 0-2.
712 auto low_sh1a = _mm_shuffle_epi32(low, 0x2);
713 auto low_red1a = _mm_min_epi32(low_red1, low_sh1a);
714 return _mm_extract_epi32(low_red1a, 0);
// 0x1B reverses the lane order, so low_red2 holds the minimum of lanes 0-3.
717 auto low_sh2 = _mm_shuffle_epi32(low_red1, 0x1B);
720 auto low_red2 = _mm_min_epi32(low_red1, low_sh2);
724 return _mm_extract_epi32(low_red2, 0);
// Upper 128-bit half (lanes 4-7).
728 auto hi = _mm256_extractf128_si256(m_value, 1);
// Lane 0 of red_5 combines lanes 0-3 with lane 4.
732 auto red_5 = _mm_min_epi32(low_red2, hi);
733 return _mm_extract_epi32(red_5, 0);
// Lane 0 of hi_red1 is min(lane 4, lane 5).
736 auto hi_sh1 = _mm_shuffle_epi32(hi, 0xB1);
737 auto hi_red1 = _mm_min_epi32(hi, hi_sh1);
// Lane 0 of red_6 combines lanes 0-3 with lanes 4-5.
741 auto red_6 = _mm_min_epi32(low_red2, hi_red1);
742 return _mm_extract_epi32(red_6, 0);
// 0x2 moves lane 6 into lane 0 of hi_sh7; lane 0 of red_7 covers lanes 0-6.
747 auto hi_sh7 = _mm_shuffle_epi32(hi, 0x2);
748 auto hi_red_6 = _mm_min_epi32(hi_sh7, hi_red1);
749 auto red_7 = _mm_min_epi32(low_red2, hi_red_6);
750 return _mm_extract_epi32(red_7, 0);
// Full reduction of the upper half, then combine with the lower half.
753 auto hi_sh2 = _mm_shuffle_epi32(hi_red1, 0x1B);
754 auto hi_red2 = _mm_min_epi32(hi_red1, hi_sh2);
758 auto hi_low = _mm_min_epi32(hi_red2, low_red2);
759 return _mm_extract_epi32(hi_low, 0);
767 self_type vmin(self_type b)
const
772 auto low_a = _mm256_castsi256_si128(m_value);
773 auto low_b = _mm256_castsi256_si128(b.m_value);
774 auto res_low = _mm256_castsi128_si256(_mm_min_epi32(low_a, low_b));
777 auto hi_a = _mm256_extractf128_si256(m_value, 1);
778 auto hi_b = _mm256_extractf128_si256(b.m_value, 1);
779 auto res_hi = _mm_min_epi32(hi_a, hi_b);
782 return self_type(_mm256_insertf128_si256(res_low, res_hi, 1));
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155