22 #ifndef RAJA_policy_vector_register_avx_float_HPP
23 #define RAJA_policy_vector_register_avx_float_HPP
25 #include "RAJA/config.hpp"
30 #include <immintrin.h>
39 class Register<float, avx_register>
40 :
public internal::expt::RegisterBase<Register<float, avx_register>>
// CRTP base: RegisterBase is parameterized on this very specialization.
43 using base_type = internal::expt::RegisterBase<Register<float, avx_register>>;
// Policy tag that selects this specialization.
45 using register_policy = avx_register;
// Shorthand for this specialization's own type.
46 using self_type = Register<float, avx_register>;
// Scalar type held in each lane.
47 using element_type = float;
// Native 256-bit AVX vector type used as storage.
48 using register_type = __m256;
// Companion register of int32_t under the same avx_register policy.
50 using int_vector_type = Register<int32_t, avx_register>;
// The underlying AVX vector holding the lane values.
54 register_type m_value;
57 __m256i createMask(camp::idx_t N)
const
60 return _mm256_set_epi32(N >= 8 ? -1 : 0, N >= 7 ? -1 : 0, N >= 6 ? -1 : 0,
61 N >= 5 ? -1 : 0, N >= 4 ? -1 : 0, N >= 3 ? -1 : 0,
62 N >= 2 ? -1 : 0, N >= 1 ? -1 : 0);
// Lane count of this register: a 256-bit __m256 holds eight 32-bit floats.
66 static constexpr camp::idx_t s_num_elem = 8;
72 Register() : base_type(), m_value(_mm256_setzero_ps()) {}
78 explicit Register(register_type
const& c) : base_type(), m_value(c) {}
// Constructs the register from individual scalars.
// NOTE(review): only the first parameter (x0) and the member initializer are
// visible in this excerpt; the remaining parameters (presumably x1..x7) and
// the constructor body are not shown — confirm against the full header.
84 Register(element_type x0,
// _mm256_set_ps takes its arguments from the highest lane down to lane 0,
// so passing x7..x0 places x0 in lane 0 and x7 in lane 7.
92 : m_value(_mm256_set_ps(x7, x6, x5, x4, x3, x2, x1, x0))
99 Register(self_type
const& c) : base_type(), m_value(c.m_value) {}
// Copy-assignment from another register of the same type.
// NOTE(review): the body is not visible in this excerpt; presumably it copies
// c.m_value and returns *this — confirm against the full header.
105 self_type& operator=(self_type
const& c)
116 Register(element_type
const& c) : m_value(_mm256_set1_ps(c)) {}
// Loads eight consecutive floats starting at ptr into the register.
// Uses the unaligned load intrinsic, so ptr needs no particular alignment.
// NOTE(review): the enclosing braces and the return statement are not visible
// in this excerpt; the signature returns self_type&.
123 self_type& load_packed(element_type
const* ptr)
125 m_value = _mm256_loadu_ps(ptr);
// Loads the first N consecutive floats starting at ptr into lanes 0..N-1.
// The mask from createMask(N) enables only those lanes; _mm256_maskload_ps
// zeroes the lanes that are masked off.
// NOTE(review): the enclosing braces and the return statement are not visible
// in this excerpt; the signature returns self_type&.
135 self_type& load_packed_n(element_type
const* ptr, camp::idx_t N)
137 m_value = _mm256_maskload_ps(ptr, createMask(N));
// Loads eight floats spaced `stride` elements apart: lane i receives
// ptr[i * stride] for i = 0..7.
// NOTE(review): writing m_value[i] relies on the compiler allowing operator[]
// on __m256 (a vector-extension feature) — confirm for all supported compilers.
// NOTE(review): the enclosing braces and the return statement are not visible
// in this excerpt; the signature returns self_type&.
146 self_type& load_strided(element_type
const* ptr, camp::idx_t stride)
148 for (camp::idx_t i = 0; i < 8; ++i)
150 m_value[i] = ptr[i * stride];
// Loads the first N floats spaced `stride` elements apart into lanes 0..N-1.
// NOTE(review): the parameter list is truncated in this excerpt; the body uses
// `stride` and `N`, which are presumably the remaining parameters — confirm.
// NOTE(review): the enclosing braces and the return statement are not visible
// in this excerpt; the signature returns self_type&.
161 self_type& load_strided_n(element_type
const* ptr,
// Clear every lane first so lanes N..7 hold zero after the partial load.
165 m_value = _mm256_setzero_ps();
166 for (camp::idx_t i = 0; i < N; ++i)
168 m_value[i] = ptr[i * stride];
// Stores all eight lanes to consecutive floats starting at ptr.
// Uses the unaligned store intrinsic, so ptr needs no particular alignment.
// NOTE(review): the enclosing braces and the return statement are not visible
// in this excerpt; the signature returns self_type const&.
178 self_type
const& store_packed(element_type* ptr)
const
180 _mm256_storeu_ps(ptr, m_value);
// Stores lanes 0..N-1 to consecutive floats starting at ptr.
// The mask from createMask(N) enables only those lanes, so memory behind the
// masked-off lanes is left untouched by _mm256_maskstore_ps.
// NOTE(review): the enclosing braces and the return statement are not visible
// in this excerpt; the signature returns self_type const&.
189 self_type
const& store_packed_n(element_type* ptr, camp::idx_t N)
const
191 _mm256_maskstore_ps(ptr, createMask(N), m_value);
// Stores all eight lanes spaced `stride` elements apart: lane i is written to
// ptr[i * stride] for i = 0..7.
// NOTE(review): reading m_value[i] relies on the compiler allowing operator[]
// on __m256 (a vector-extension feature) — confirm for all supported compilers.
// NOTE(review): the enclosing braces and the return statement are not visible
// in this excerpt; the signature returns self_type const&.
200 self_type
const& store_strided(element_type* ptr, camp::idx_t stride)
const
202 for (camp::idx_t i = 0; i < 8; ++i)
204 ptr[i * stride] = m_value[i];
// Stores lanes 0..N-1 spaced `stride` elements apart: lane i is written to
// ptr[i * stride].
// NOTE(review): the parameter list is truncated in this excerpt; the body uses
// `stride` and `N`, which are presumably the remaining parameters — confirm.
// NOTE(review): the enclosing braces and the return statement are not visible
// in this excerpt; the signature returns self_type const&.
214 self_type
const& store_strided_n(element_type* ptr,
218 for (camp::idx_t i = 0; i < N; ++i)
220 ptr[i * stride] = m_value[i];
231 element_type
get(camp::idx_t i)
const {
return m_value[i]; }
// Per its signature, takes a scalar `value` and a lane index `i` and returns a
// reference to a register.
// NOTE(review): the body is not visible in this excerpt; presumably it writes
// `value` into lane i and returns *this — confirm against the full header.
239 self_type& set(element_type value, camp::idx_t i)
// Replicates `value` into every lane of the register.
// NOTE(review): the enclosing braces and the return statement are not visible
// in this excerpt; the signature returns self_type&.
248 self_type& broadcast(element_type
const& value)
250 m_value = _mm256_set1_ps(value);
// Overwrites this register's lanes with those of `src`.
// NOTE(review): the enclosing braces and the return statement are not visible
// in this excerpt; the signature returns self_type&.
257 self_type& copy(self_type
const& src)
259 m_value = src.m_value;
266 self_type add(self_type
const& b)
const
268 return self_type(_mm256_add_ps(m_value, b.m_value));
274 self_type subtract(self_type
const& b)
const
276 return self_type(_mm256_sub_ps(m_value, b.m_value));
282 self_type multiply(self_type
const& b)
const
284 return self_type(_mm256_mul_ps(m_value, b.m_value));
290 self_type divide(self_type
const& b)
const
292 return self_type(_mm256_div_ps(m_value, b.m_value));
298 self_type divide_n(self_type
const& b, camp::idx_t N)
const
301 return self_type(_mm256_set_ps(
302 N >= 8 ?
get(7) / b.get(7) : 0, N >= 7 ?
get(6) / b.get(6) : 0,
303 N >= 6 ?
get(5) / b.get(5) : 0, N >= 5 ?
get(4) / b.get(4) : 0,
304 N >= 4 ?
get(3) / b.get(3) : 0, N >= 3 ?
get(2) / b.get(2) : 0,
305 N >= 2 ?
get(1) / b.get(1) : 0, N >= 1 ?
get(0) / b.get(0) : 0));
313 element_type
sum()
const
316 auto sh1 = _mm256_permute_ps(m_value, 0xB1);
317 auto red1 = _mm256_add_ps(m_value, sh1);
320 auto sh2 = _mm256_permute_ps(red1, 0x4E);
321 auto red2 = _mm256_add_ps(red1, sh2);
323 return red2[0] + red2[4];
331 element_type
max()
const
334 auto sh1 = _mm256_permute_ps(m_value, 0xB1);
335 auto red1 = _mm256_max_ps(m_value, sh1);
338 auto sh2 = _mm256_permute_ps(red1, 0x4E);
339 auto red2 = _mm256_max_ps(red1, sh2);
342 return RAJA::max<element_type>(red2[0], red2[4]);
// Horizontal maximum over a leading subset of lanes selected by N.
// NOTE(review): the conditions guarding the individual return statements (and
// any early-out for small or out-of-range N) are not visible in this excerpt;
// presumably each return below is selected by a test on N — confirm against
// the full header before relying on the per-line notes.
350 element_type max_n(camp::idx_t N)
const
// Covers lanes 0 and 1 only.
363 return RAJA::max<element_type>(m_value[0], m_value[1]);
// 0xB1 swaps neighbouring lanes (0<->1, 2<->3, ...) inside each 128-bit half.
367 auto sh1 = _mm256_permute_ps(m_value, 0xB1);
// Mask 0x40 takes lane 6 from m_value instead of the swapped copy, so lane 7
// no longer contributes to lane 6 of red1 (presumably used when only seven
// lanes are active — confirm).
372 sh1 = _mm256_blend_ps(sh1, m_value, 0x40);
// red1 holds pair-wise maxima: red1[0] = max(lane 0, lane 1), and so on.
375 auto red1 = _mm256_max_ps(m_value, sh1);
// Covers lanes 0..2.
380 return RAJA::max<element_type>(red1[0], m_value[2]);
// 0x4E swaps the two pairs inside each 128-bit half; after the max every lane
// of a half carries the reduction of that half's pair maxima.
385 auto sh2 = _mm256_permute_ps(red1, 0x4E);
386 auto red2 = _mm256_max_ps(red1, sh2);
// red2[0] covers lanes 0..3; adding lane 4 covers lanes 0..4.
394 return RAJA::max<element_type>(red2[0], m_value[4]);
// red1[4] covers lanes 4 and 5, so this covers lanes 0..5.
398 return RAJA::max<element_type>(red2[0], red1[4]);
// red2[4] covers the upper half (without lane 7 if the blend above was applied).
402 return RAJA::max<element_type>(red2[0], red2[4]);
410 self_type vmax(self_type a)
const
412 return self_type(_mm256_max_ps(m_value, a.m_value));
420 element_type
min()
const
423 auto sh1 = _mm256_permute_ps(m_value, 0xB1);
424 auto red1 = _mm256_min_ps(m_value, sh1);
427 auto sh2 = _mm256_permute_ps(red1, 0x4E);
428 auto red2 = _mm256_min_ps(red1, sh2);
431 return RAJA::min<element_type>(red2[0], red2[4]);
// Horizontal minimum over a leading subset of lanes selected by N.
// NOTE(review): the conditions guarding the individual return statements (and
// any early-out for small or out-of-range N) are not visible in this excerpt;
// presumably each return below is selected by a test on N — confirm against
// the full header before relying on the per-line notes.
439 element_type min_n(camp::idx_t N)
const
// Covers lanes 0 and 1 only.
452 return RAJA::min<element_type>(m_value[0], m_value[1]);
// 0xB1 swaps neighbouring lanes (0<->1, 2<->3, ...) inside each 128-bit half.
456 auto sh1 = _mm256_permute_ps(m_value, 0xB1);
// Mask 0x40 takes lane 6 from m_value instead of the swapped copy, so lane 7
// no longer contributes to lane 6 of red1 (presumably used when only seven
// lanes are active — confirm).
461 sh1 = _mm256_blend_ps(sh1, m_value, 0x40);
// red1 holds pair-wise minima: red1[0] = min(lane 0, lane 1), and so on.
464 auto red1 = _mm256_min_ps(m_value, sh1);
// Covers lanes 0..2.
469 return RAJA::min<element_type>(red1[0], m_value[2]);
// 0x4E swaps the two pairs inside each 128-bit half; after the min every lane
// of a half carries the reduction of that half's pair minima.
474 auto sh2 = _mm256_permute_ps(red1, 0x4E);
475 auto red2 = _mm256_min_ps(red1, sh2);
// red2[0] covers lanes 0..3; adding lane 4 covers lanes 0..4.
483 return RAJA::min<element_type>(red2[0], m_value[4]);
// red1[4] covers lanes 4 and 5, so this covers lanes 0..5.
487 return RAJA::min<element_type>(red2[0], red1[4]);
// red2[4] covers the upper half (without lane 7 if the blend above was applied).
491 return RAJA::min<element_type>(red2[0], red2[4]);
499 self_type vmin(self_type a)
const
501 return self_type(_mm256_min_ps(m_value, a.m_value));
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155