22 #ifndef RAJA_policy_vector_register_avx_int64_HPP
23 #define RAJA_policy_vector_register_avx_int64_HPP
25 #include "RAJA/config.hpp"
30 #include <immintrin.h>
38 class Register<int64_t, avx_register>
39 :
public internal::expt::RegisterBase<Register<int64_t, avx_register>>
43 internal::expt::RegisterBase<Register<int64_t, avx_register>>;
// Member type aliases for the AVX int64_t register specialization.
45 using register_policy = avx_register;
46 using self_type = Register<int64_t, avx_register>;
47 using element_type = int64_t;
// The raw hardware vector type wrapped by this class.
48 using register_type = __m256i;
// Integer register type with the same lane layout; for this int64_t
// specialization it is the class itself.
50 using int_vector_type = Register<int64_t, avx_register>;
// Raw 256-bit storage holding four int64_t lanes; lane 0 is the element
// addressed by index 0 in the extract/insert intrinsics (get/set).
54 register_type m_value;
// Builds the lane mask used by the masked load/store paths
// (load_packed_n, store_packed_n): lanes 1..3 are all-ones (-1) when the
// lane index is < N and zero otherwise. _mm256_set_epi64x takes its
// arguments highest lane first, so the arguments below are lanes 3, 2, 1.
// NOTE(review): the lane-0 argument (presumably `N >= 1 ? -1 : 0`) and the
// closing of the call are on lines not visible in this listing — confirm.
57 __m256i createMask(camp::idx_t N)
const
60 return _mm256_set_epi64x(N >= 4 ? -1 : 0, N >= 3 ? -1 : 0, N >= 2 ? -1 : 0,
65 __m256i createStridedOffsets(camp::idx_t stride)
const
68 return _mm256_set_epi64x(3 * stride, 2 * stride, stride, 0);
// Shuffles the 64-bit lanes of x within each 128-bit half using the double
// permute _mm256_permute_pd. The integer vector is only reinterpreted as
// double for the shuffle (cast intrinsics), so the bit patterns are moved
// unchanged. Example: sum() uses permute<0x5>, which swaps each pair of
// adjacent lanes ({a0,a1,a2,a3} -> {a1,a0,a3,a2}).
// NOTE(review): `perm` is declared on a line not visible in this listing
// (presumably a template parameter, since the intrinsic needs an immediate).
78 RAJA_INLINE __m256i permute(__m256i x)
const
80 return _mm256_castpd_si256(_mm256_permute_pd(_mm256_castsi256_pd(x), perm));
84 static constexpr camp::idx_t s_num_elem = 4;
90 Register() : base_type(), m_value(_mm256_setzero_si256()) {}
96 explicit Register(register_type
const& c) : base_type(), m_value(c) {}
// Builds a register from four explicit values: x0 goes to lane 0, x1 to
// lane 1, and so on. _mm256_set_epi64x takes its arguments highest lane
// first, hence the reversed order below.
// NOTE(review): unlike the other constructors this one does not name
// base_type() in its initializer list; the constructor body is on lines
// not visible in this listing.
102 Register(element_type x0, element_type x1, element_type x2, element_type x3)
103 : m_value(_mm256_set_epi64x(x3, x2, x1, x0))
110 Register(self_type
const& c) : base_type(), m_value(c.m_value) {}
// Copy assignment from another register.
// NOTE(review): the body is not visible in this listing; presumably it
// copies c.m_value and returns *this — confirm against the full header.
116 self_type& operator=(self_type
const& c)
127 Register(element_type
const& c) : m_value(_mm256_set1_epi64x(c)) {}
// load_packed: loads four contiguous int64_t values ptr[0..3] into lanes
// 0..3. The load is unaligned (loadu), so ptr need not be 32-byte aligned.
134 self_type& load_packed(element_type
const* ptr)
136 m_value = _mm256_loadu_si256(
reinterpret_cast<__m256i const*
>(ptr));
// load_packed_n: loads only ptr[0..N-1] into the low lanes via a masked
// load driven by createMask(N). Masked-off lanes read as zero and their
// memory is not accessed. The data is moved through the double-typed
// maskload and cast back bit-for-bit, because there is no 64-bit integer
// maskload before AVX2.
146 self_type& load_packed_n(element_type
const* ptr, camp::idx_t N)
148 m_value = _mm256_castpd_si256(_mm256_maskload_pd(
149 reinterpret_cast<double const*
>(ptr), createMask(N)));
// load_strided: gathers ptr[0], ptr[stride], ptr[2*stride], ptr[3*stride]
// into lanes 0..3 with a scalar loop.
// NOTE(review): `m_value[i]` subscripts an __m256i directly, which relies on
// the GCC/Clang vector extension (MSVC's __m256i has no operator[]);
// set(value, i) or a single _mm256_set_epi64x would be portable.
158 self_type& load_strided(element_type
const* ptr, camp::idx_t stride)
160 for (camp::idx_t i = 0; i < 4; ++i)
162 m_value[i] = ptr[i * stride];
// load_strided_n: zeroes the register, then gathers the first N strided
// values ptr[i * stride] into lanes 0..N-1; the remaining lanes stay zero.
// NOTE(review): the remaining parameters (stride, N) are declared on lines
// not visible here. The loop does not clamp N to 4, so N > 4 would index
// past the four lanes — presumably callers guarantee N <= 4; confirm.
173 self_type& load_strided_n(element_type
const* ptr,
177 m_value = _mm256_setzero_si256();
178 for (camp::idx_t i = 0; i < N; ++i)
180 m_value[i] = ptr[i * stride];
// store_packed: writes lanes 0..3 to four contiguous elements ptr[0..3].
// The store is unaligned (storeu), so ptr need not be 32-byte aligned.
190 self_type
const& store_packed(element_type* ptr)
const
192 _mm256_storeu_si256(
reinterpret_cast<__m256i*
>(ptr), m_value);
// store_packed_n: writes only lanes 0..N-1 to ptr[0..N-1] via a masked
// store driven by createMask(N); memory for masked-off lanes is not
// written. It goes through the double-typed maskstore because there is no
// 64-bit integer maskstore before AVX2.
// NOTE(review): `reinterpret_cast<__m256d>(m_value)` is a GCC/Clang
// vector-type extension and is ill-formed on MSVC. `_mm256_castsi256_pd`
// is the portable bit-cast, mirroring the `_mm256_castpd_si256` already
// used by load_packed_n.
201 self_type
const& store_packed_n(element_type* ptr, camp::idx_t N)
const
203 _mm256_maskstore_pd(
reinterpret_cast<double*
>(ptr), createMask(N),
204 reinterpret_cast<__m256d
>(m_value));
// store_strided: scatters lanes 0..3 to ptr[0], ptr[stride], ptr[2*stride],
// ptr[3*stride] with a scalar loop.
// NOTE(review): reading `m_value[i]` relies on the GCC/Clang vector
// subscript extension; get(i) would be the portable spelling.
213 self_type
const& store_strided(element_type* ptr, camp::idx_t stride)
const
215 for (camp::idx_t i = 0; i < 4; ++i)
217 ptr[i * stride] = m_value[i];
// store_strided_n: scatters only lanes 0..N-1 to ptr[i * stride].
// NOTE(review): the remaining parameters (stride, N) are declared on lines
// not visible here. N is not clamped to 4, so N > 4 would read past the
// four lanes — presumably callers guarantee N <= 4; confirm.
227 self_type
const& store_strided_n(element_type* ptr,
231 for (camp::idx_t i = 0; i < N; ++i)
233 ptr[i * stride] = m_value[i];
// get: returns the value in lane i. _mm256_extract_epi64 needs a
// compile-time lane index, so the runtime i is dispatched to one of four
// constant-index extracts.
// NOTE(review): the dispatch statements (presumably a switch on i) are not
// visible, so the behaviour for i outside 0..3 cannot be determined here.
244 element_type
get(camp::idx_t i)
const
250 return _mm256_extract_epi64(m_value, 0);
252 return _mm256_extract_epi64(m_value, 1);
254 return _mm256_extract_epi64(m_value, 2);
256 return _mm256_extract_epi64(m_value, 3);
// set: overwrites lane i with value, using the same constant-index dispatch
// as get() because _mm256_insert_epi64 also needs an immediate lane index.
// NOTE(review): the dispatch and return statements are not visible here.
267 self_type& set(element_type value, camp::idx_t i)
273 m_value = _mm256_insert_epi64(m_value, value, 0);
276 m_value = _mm256_insert_epi64(m_value, value, 1);
279 m_value = _mm256_insert_epi64(m_value, value, 2);
282 m_value = _mm256_insert_epi64(m_value, value, 3);
// broadcast: sets every lane of this register to value.
292 self_type& broadcast(element_type
const& value)
294 m_value = _mm256_set1_epi64x(value);
// copy: replaces this register's lanes with those of src.
301 self_type& copy(self_type
const& src)
303 m_value = src.m_value;
310 self_type add(self_type
const& b)
const
315 auto low_a = _mm256_castsi256_si128(m_value);
316 auto low_b = _mm256_castsi256_si128(b.m_value);
317 auto res_low = _mm256_castsi128_si256(_mm_add_epi64(low_a, low_b));
320 auto hi_a = _mm256_extractf128_si256(m_value, 1);
321 auto hi_b = _mm256_extractf128_si256(b.m_value, 1);
322 auto res_hi = _mm_add_epi64(hi_a, hi_b);
325 return self_type(_mm256_insertf128_si256(res_low, res_hi, 1));
331 self_type subtract(self_type
const& b)
const
336 auto low_a = _mm256_castsi256_si128(m_value);
337 auto low_b = _mm256_castsi256_si128(b.m_value);
338 auto res_low = _mm256_castsi128_si256(_mm_sub_epi64(low_a, low_b));
341 auto hi_a = _mm256_extractf128_si256(m_value, 1);
342 auto hi_b = _mm256_extractf128_si256(b.m_value, 1);
343 auto res_hi = _mm_sub_epi64(hi_a, hi_b);
346 return self_type(_mm256_insertf128_si256(res_low, res_hi, 1));
352 self_type multiply(self_type
const& b)
const
355 return self_type(_mm256_set_epi64x(
get(3) * b.get(3),
get(2) * b.get(2),
356 get(1) * b.get(1),
get(0) * b.get(0)));
362 self_type divide(self_type
const& b)
const
365 return self_type(_mm256_set_epi64x(
get(3) / b.get(3),
get(2) / b.get(2),
366 get(1) / b.get(1),
get(0) / b.get(0)));
372 self_type divide_n(self_type
const& b, camp::idx_t N)
const
375 return self_type(_mm256_set_epi64x(
376 N >= 4 ?
get(3) / b.get(3) : 0, N >= 3 ?
get(2) / b.get(2) : 0,
377 N >= 2 ?
get(1) / b.get(1) : 0, N >= 1 ?
get(0) / b.get(0) : 0));
385 element_type
sum()
const
388 auto sh1 = permute<0x5>(m_value);
391 auto low_a = _mm256_castsi256_si128(m_value);
392 auto low_b = _mm256_castsi256_si128(sh1);
393 auto res_low = _mm_add_epi64(low_a, low_b);
396 auto hi_a = _mm256_extractf128_si256(m_value, 1);
397 auto hi_b = _mm256_extractf128_si256(sh1, 1);
398 auto res_hi = _mm_add_epi64(hi_a, hi_b);
401 auto res = _mm_add_epi64(res_hi, res_low);
404 return _mm_extract_epi64(res, 0);
// max: horizontal maximum over all four lanes, computed as a scalar running
// maximum `red` updated with v1, v2, v3 (ties keep the current `red`).
// NOTE(review): the declarations of red and v1..v3 are not visible;
// presumably red starts from lane 0 and vK = get(K) — confirm.
412 element_type
max()
const
418 red = red < v1 ? v1 : red;
421 red = red < v2 ? v2 : red;
424 red = red < v3 ? v3 : red;
// max_n: same running maximum, restricted to the first N lanes.
// NOTE(review): the conditions guarding each update (presumably on N) and
// the initial value of red are on lines not visible; the result for
// N <= 0 cannot be determined from here.
434 element_type max_n(camp::idx_t N)
const
447 red = red < v1 ? v1 : red;
452 red = red < v2 ? v2 : red;
457 red = red < v3 ? v3 : red;
468 self_type vmax(self_type a)
const
470 return self_type(_mm256_set_epi64x(
get(3) > a.get(3) ?
get(3) : a.get(3),
471 get(2) > a.get(2) ?
get(2) : a.get(2),
472 get(1) > a.get(1) ?
get(1) : a.get(1),
473 get(0) > a.get(0) ?
get(0) : a.get(0)));
// min: horizontal minimum over all four lanes, computed as a scalar running
// minimum `red` updated with v1, v2, v3 (ties keep the current `red`).
// NOTE(review): the declarations of red and v1..v3 are not visible;
// presumably red starts from lane 0 and vK = get(K) — confirm.
481 element_type
min()
const
488 red = red > v1 ? v1 : red;
491 red = red > v2 ? v2 : red;
494 red = red > v3 ? v3 : red;
// min_n: same running minimum, restricted to the first N lanes.
// NOTE(review): the conditions guarding each update (presumably on N) and
// the initial value of red are on lines not visible; the result for
// N <= 0 cannot be determined from here.
504 element_type min_n(camp::idx_t N)
const
517 red = red > v1 ? v1 : red;
522 red = red > v2 ? v2 : red;
527 red = red > v3 ? v3 : red;
538 self_type vmin(self_type a)
const
540 return self_type(_mm256_set_epi64x(
get(3) < a.get(3) ?
get(3) : a.get(3),
541 get(2) < a.get(2) ?
get(2) : a.get(2),
542 get(1) < a.get(1) ?
get(1) : a.get(1),
543 get(0) < a.get(0) ?
get(0) : a.get(0)));
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155