22 #ifndef RAJA_policy_vector_register_avx2_int64_HPP
23 #define RAJA_policy_vector_register_avx2_int64_HPP
25 #include "RAJA/config.hpp"
30 #include <immintrin.h>
38 class Register<int64_t, avx2_register>
39 :
public internal::expt::RegisterBase<Register<int64_t, avx2_register>>
43 internal::expt::RegisterBase<Register<int64_t, avx2_register>>;
45 using register_policy = avx2_register;
46 using self_type = Register<int64_t, avx2_register>;
47 using element_type = int64_t;
48 using register_type = __m256i;
50 using int_vector_type = Register<int64_t, avx2_register>;
53 register_type m_value;
56 __m256i createMask(camp::idx_t N)
const
59 return _mm256_set_epi64x(N >= 4 ? -1 : 0, N >= 3 ? -1 : 0, N >= 2 ? -1 : 0,
64 __m256i createStridedOffsets(camp::idx_t stride)
const
67 return _mm256_set_epi64x(3 * stride, 2 * stride, stride, 0);
77 RAJA_INLINE __m256i permute(__m256i x)
const
79 return _mm256_castpd_si256(_mm256_permute_pd(_mm256_castsi256_pd(x), perm));
83 static constexpr camp::idx_t s_num_elem = 4;
89 Register() : m_value(_mm256_setzero_si256()) {}
95 Register(element_type x0, element_type x1, element_type x2, element_type x3)
96 : m_value(_mm256_set_epi64x(x3, x2, x1, x0))
103 explicit Register(register_type
const& c) : m_value(c) {}
109 Register(self_type
const& c) : base_type(c), m_value(c.m_value) {}
115 self_type& operator=(self_type
const& c)
126 Register(element_type
const& c) : m_value(_mm256_set1_epi64x(c)) {}
132 constexpr register_type get_register()
const {
return m_value; }
139 self_type& load_packed(element_type
const* ptr)
141 m_value = _mm256_loadu_si256(
reinterpret_cast<__m256i const*
>(ptr));
151 self_type& load_packed_n(element_type
const* ptr, camp::idx_t N)
153 m_value = _mm256_castpd_si256(_mm256_maskload_pd(
154 reinterpret_cast<double const*
>(ptr), createMask(N)));
163 self_type& load_strided(int64_t
const* ptr, camp::idx_t stride)
165 m_value = _mm256_i64gather_epi64(
reinterpret_cast<long long const*
>(ptr),
166 createStridedOffsets(stride),
167 sizeof(element_type));
177 self_type& load_strided_n(element_type
const* ptr,
181 m_value = _mm256_mask_i64gather_epi64(
182 _mm256_set1_epi64x(0),
reinterpret_cast<long long const*
>(ptr),
183 createStridedOffsets(stride), createMask(N),
sizeof(element_type));
197 self_type& gather(element_type
const* ptr, int_vector_type offsets)
199 #ifdef RAJA_ENABLE_VECTOR_STATS
200 RAJA::tensor_stats::num_vector_load_strided_n++;
203 _mm256_i64gather_epi64(
reinterpret_cast<long long const*
>(ptr),
204 offsets.get_register(),
sizeof(element_type));
218 self_type& gather_n(element_type
const* ptr,
219 int_vector_type offsets,
222 #ifdef RAJA_ENABLE_VECTOR_STATS
223 RAJA::tensor_stats::num_vector_load_strided_n++;
225 m_value = _mm256_mask_i64gather_epi64(
226 _mm256_setzero_si256(),
reinterpret_cast<long long const*
>(ptr),
227 offsets.get_register(), createMask(N),
sizeof(element_type));
236 self_type
const& store_packed(element_type* ptr)
const
238 _mm256_storeu_si256(
reinterpret_cast<__m256i*
>(ptr), m_value);
247 self_type
const& store_packed_n(element_type* ptr, camp::idx_t N)
const
249 _mm256_maskstore_epi64(
reinterpret_cast<long long*
>(ptr), createMask(N),
259 self_type
const& store_strided(element_type* ptr, camp::idx_t stride)
const
261 for (camp::idx_t i = 0; i < 4; ++i)
263 ptr[i * stride] = m_value[i];
273 self_type
const& store_strided_n(element_type* ptr,
277 for (camp::idx_t i = 0; i < N; ++i)
279 ptr[i * stride] = m_value[i];
290 element_type
get(camp::idx_t i)
const
296 return _mm256_extract_epi64(m_value, 0);
298 return _mm256_extract_epi64(m_value, 1);
300 return _mm256_extract_epi64(m_value, 2);
302 return _mm256_extract_epi64(m_value, 3);
313 self_type& set(element_type value, camp::idx_t i)
319 m_value = _mm256_insert_epi64(m_value, value, 0);
322 m_value = _mm256_insert_epi64(m_value, value, 1);
325 m_value = _mm256_insert_epi64(m_value, value, 2);
328 m_value = _mm256_insert_epi64(m_value, value, 3);
338 self_type& broadcast(element_type
const& value)
340 m_value = _mm256_set1_epi64x(value);
347 self_type& copy(self_type
const& src)
349 m_value = src.m_value;
356 self_type add(self_type
const& b)
const
358 return self_type(_mm256_add_epi64(m_value, b.m_value));
364 self_type subtract(self_type
const& b)
const
366 return self_type(_mm256_sub_epi64(m_value, b.m_value));
372 self_type multiply(self_type
const& b)
const
375 return self_type(_mm256_set_epi64x(
get(3) * b.get(3),
get(2) * b.get(2),
376 get(1) * b.get(1),
get(0) * b.get(0)));
382 self_type divide(self_type
const& b)
const
385 return self_type(_mm256_set_epi64x(
get(3) / b.get(3),
get(2) / b.get(2),
386 get(1) / b.get(1),
get(0) / b.get(0)));
392 self_type divide_n(self_type
const& b, camp::idx_t N)
const
395 return self_type(_mm256_set_epi64x(
396 N >= 4 ?
get(3) / b.get(3) : 0, N >= 3 ?
get(2) / b.get(2) : 0,
397 N >= 2 ?
get(1) / b.get(1) : 0, N >= 1 ?
get(0) / b.get(0) : 0));
405 element_type
sum()
const
409 auto sh1 = permute<0x5>(m_value);
410 auto red1 = _mm256_add_epi64(m_value, sh1);
413 return _mm256_extract_epi64(red1, 0) + _mm256_extract_epi64(red1, 2);
421 element_type
max()
const
427 red = red < v1 ? v1 : red;
430 red = red < v2 ? v2 : red;
433 red = red < v3 ? v3 : red;
443 element_type max_n(camp::idx_t N)
const
456 red = red < v1 ? v1 : red;
461 red = red < v2 ? v2 : red;
466 red = red < v3 ? v3 : red;
477 self_type vmax(self_type a)
const
479 return self_type(_mm256_set_epi64x(
get(3) > a.get(3) ?
get(3) : a.get(3),
480 get(2) > a.get(2) ?
get(2) : a.get(2),
481 get(1) > a.get(1) ?
get(1) : a.get(1),
482 get(0) > a.get(0) ?
get(0) : a.get(0)));
490 element_type
min()
const
496 red = red > v1 ? v1 : red;
499 red = red > v2 ? v2 : red;
502 red = red > v3 ? v3 : red;
512 element_type min_n(camp::idx_t N)
const
525 red = red > v1 ? v1 : red;
530 red = red > v2 ? v2 : red;
535 red = red > v3 ? v3 : red;
546 self_type vmin(self_type a)
const
548 return self_type(_mm256_set_epi64x(
get(3) < a.get(3) ?
get(3) : a.get(3),
549 get(2) < a.get(2) ?
get(2) : a.get(2),
550 get(1) < a.get(1) ?
get(1) : a.get(1),
551 get(0) < a.get(0) ?
get(0) : a.get(0)));
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155