22 #ifndef RAJA_policy_vector_register_avx512_int32_HPP
23 #define RAJA_policy_vector_register_avx512_int32_HPP
25 #include "RAJA/config.hpp"
30 #include <immintrin.h>
// Specialization of Register for int32_t elements under the avx512_register
// policy. The whole vector is held in a single __m512i (m_value below), which
// gives 16 int32_t lanes (see s_num_elem).
// NOTE(review): this listing omits many source lines (the `template<>` line,
// braces, access specifiers, switch/case labels, return statements). The
// comments in this class describe only what the visible lines show.
38 class Register<int32_t, avx512_register>
39 :
public internal::expt::RegisterBase<Register<int32_t, avx512_register>>
// The next line is the tail of an alias whose first line is not shown
// (presumably `using base_type =`, since the constructors call base_type()).
43 internal::expt::RegisterBase<Register<int32_t, avx512_register>>;
45 using register_policy = avx512_register;
46 using self_type = Register<int32_t, avx512_register>;
47 using element_type = int32_t;
48 using register_type = __m512i;
// For this int32_t register the integer-vector type is the register type itself
// (presumably used wherever an integer index/offset vector is required).
50 using int_vector_type = Register<int32_t, avx512_register>;
// Raw AVX-512 vector holding all 16 int32_t lanes.
54 register_type m_value;
// Builds the 16-bit lane mask used by this class's partial ("first N lanes")
// operations: load_packed_n, load_strided_n, store_packed_n, store_strided_n,
// max_n and min(N).
// The visible return values are 0x0000, 0x0001, 0x0003, ..., 0x7FFF, 0xFFFF,
// i.e. masks with 0 through 16 low bits set. Presumably a switch on N selects
// the mask with the low N bits set; the selecting lines are not shown here.
// NOTE(review): the result for N outside [0, 16] is not visible. Confirm the
// default case.
57 __mmask16 createMask(camp::idx_t N)
const
63 return __mmask16(0x0000);
65 return __mmask16(0x0001);
67 return __mmask16(0x0003);
69 return __mmask16(0x0007);
71 return __mmask16(0x000F);
73 return __mmask16(0x001F);
75 return __mmask16(0x003F);
77 return __mmask16(0x007F);
79 return __mmask16(0x00FF);
81 return __mmask16(0x01FF);
83 return __mmask16(0x03FF);
85 return __mmask16(0x07FF);
87 return __mmask16(0x0FFF);
89 return __mmask16(0x1FFF);
91 return __mmask16(0x3FFF);
93 return __mmask16(0x7FFF);
95 return __mmask16(0xFFFF);
// Returns the gather/scatter index vector {0, stride, 2*stride, ..., 15*stride}
// (lane k holds k * stride).
// How it works: stride is broadcast to every lane, then multiplied lane-wise
// (keeping the low 32 bits, _mm512_mullo_epi32) by a lane-sequence vector.
// _mm512_set_epi32 takes its arguments from the highest lane down, so lane k
// of that sequence holds k. The declaration of `vseq` is on a line not shown.
// NOTE(review): _mm512_set1_epi32 takes a 32-bit int, so stride (camp::idx_t,
// presumably wider) is narrowed, and k * stride is computed in 32 bits. Very
// large strides would wrap; confirm callers keep 15 * stride within int32 range.
101 __m512i createStridedOffsets(camp::idx_t stride)
const
104 auto vstride = _mm512_set1_epi32(stride);
106 _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
107 return _mm512_mullo_epi32(vstride, vseq);
// Number of int32_t lanes in one 512-bit __m512i register (512 / 32 = 16).
111 static constexpr camp::idx_t s_num_elem = 16;
// Default constructor: all 16 lanes are set to zero.
118 Register() : base_type(), m_value(_mm512_setzero_epi32()) {}
// Wraps an existing raw __m512i. Declared explicit so a raw vector is never
// converted into a Register implicitly.
124 explicit Register(register_type
const& c) : base_type(), m_value(c) {}
// Copy constructor: copies the underlying vector.
130 Register(self_type
const& c) : base_type(), m_value(c.m_value) {}
// Copy assignment. Its body is not shown in this listing.
136 self_type& operator=(self_type
const& c)
// Broadcast constructor: every lane is set to c.
// This constructor is not `explicit`, so an int32_t converts implicitly to a
// register. Presumably that is intended, to allow scalar/register mixing in
// expressions. NOTE(review): confirm.
148 Register(element_type
const& c) : base_type(), m_value(_mm512_set1_epi32(c))
// Loads 16 consecutive int32_t values starting at ptr into all lanes, using
// the unaligned ("loadu") form of the load.
// For GCC 7-9 the generic _mm512_loadu_si512 is used instead of
// _mm512_loadu_epi32, presumably because those GCC releases do not provide
// the latter. The #else/#endif lines are not shown in this listing.
156 self_type& load_packed(element_type
const* ptr)
159 #if defined(__GNUC__) && ((__GNUC__ >= 7) && (__GNUC__ <= 9))
160 m_value = _mm512_loadu_si512(ptr);
162 m_value = _mm512_loadu_epi32(ptr);
// Loads only the first N consecutive values from ptr, using the mask from
// createMask(N). Lanes outside the mask take their value from the zero source
// vector (_mm512_setzero_epi32()).
// The assignment target is on a line not shown; presumably it is m_value.
173 self_type& load_packed_n(element_type
const* ptr, camp::idx_t N)
177 _mm512_mask_loadu_epi32(_mm512_setzero_epi32(), createMask(N), ptr);
// Gathers 16 values, where lane k is read from ptr[k * stride]. The offsets
// from createStridedOffsets(stride) are element counts, scaled to bytes by
// sizeof(element_type).
186 self_type& load_strided(element_type
const* ptr, camp::idx_t stride)
189 m_value = _mm512_i32gather_epi32(createStridedOffsets(stride), ptr,
190 sizeof(element_type));
// Masked gather. Same as load_strided, but only the lanes selected by
// createMask(N) are read from ptr[k * stride]; every other lane comes from the
// zero source vector.
// The remaining parameters are on lines not shown; presumably they are stride
// and N, as used below.
200 self_type& load_strided_n(element_type
const* ptr,
205 m_value = _mm512_mask_i32gather_epi32(_mm512_setzero_epi32(), createMask(N),
206 createStridedOffsets(stride), ptr,
207 sizeof(element_type));
// Stores all 16 lanes to consecutive int32_t locations starting at ptr, using
// the unaligned ("storeu") form of the store.
// For GCC 7-9 the generic _mm512_storeu_si512 is used instead of
// _mm512_storeu_epi32, presumably because those releases lack the latter.
// The #else/#endif lines are not shown.
216 self_type
const& store_packed(element_type* ptr)
const
219 #if defined(__GNUC__) && ((__GNUC__ >= 7) && (__GNUC__ <= 9))
220 _mm512_storeu_si512(ptr, m_value);
222 _mm512_storeu_epi32(ptr, m_value);
// Stores only the lanes selected by createMask(N) to consecutive locations
// starting at ptr. Unselected lanes are not written.
232 self_type
const& store_packed_n(element_type* ptr, camp::idx_t N)
const
235 _mm512_mask_storeu_epi32(ptr, createMask(N), m_value);
// Scatters all 16 lanes so that lane k is written to ptr[k * stride]. The
// offsets from createStridedOffsets(stride) are scaled by sizeof(element_type).
// NOTE(review): if stride == 0, all lanes target the same address; which lane
// wins is determined by the scatter instruction's ordering rules.
244 self_type
const& store_strided(element_type* ptr, camp::idx_t stride)
const
247 _mm512_i32scatter_epi32(ptr, createStridedOffsets(stride), m_value,
248 sizeof(element_type));
// Masked scatter. Same as store_strided, but only the lanes selected by
// createMask(N) are written to ptr[k * stride].
// The remaining parameters are on lines not shown; presumably they are stride
// and N, as used below.
257 self_type
const& store_strided_n(element_type* ptr,
262 _mm512_mask_i32scatter_epi32(ptr, createMask(N),
263 createStridedOffsets(stride), m_value,
264 sizeof(element_type));
// Returns the value of lane i.
// How it works: concatenating m_value with itself and shifting right by k
// 32-bit elements (_mm512_alignr_epi32 with immediate k) rotates lane k into
// lane 0, which _mm512_cvtsi512_si32 then extracts. The shift must be a
// compile-time immediate, hence one return per lane. The case labels that map
// i to k are not shown; presumably k == i.
// NOTE(review): the result for i outside [0, 15] is not visible. Confirm.
274 element_type
get(camp::idx_t i)
const
277 #if defined(__GNUC__) && ((__GNUC__ >= 7) && (__GNUC__ <= 10))
// Emulates _mm512_cvtsi512_si32 for these GCC versions (presumably missing
// there): take the low 128 bits and extract lane 0.
// NOTE(review): this macro uses a name in the compiler's intrinsic namespace,
// and no #undef is visible, so it stays defined for the rest of the
// translation unit. Consider a private inline helper instead.
278 #define _mm512_cvtsi512_si32(x) _mm_cvtsi128_si32(_mm512_castsi512_si128(x))
284 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 0));
286 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 1));
288 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 2));
290 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 3));
292 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 4));
294 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 5));
296 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 6));
298 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 7));
300 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 8));
302 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 9));
304 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 10));
306 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 11));
308 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 12));
310 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 13));
312 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 14));
314 return _mm512_cvtsi512_si32(_mm512_alignr_epi32(m_value, m_value, 15));
// Replaces lane i with value and leaves the other lanes unchanged: the
// single-bit mask (1 << i) selects which lane _mm512_mask_set1_epi32 writes.
// NOTE(review): `1 << i` is an int shift with no visible bounds check. For
// i >= 16 the mask truncates to 0 when converted to __mmask16 (no lane is
// set), and i >= 32 or i < 0 is undefined behavior.
325 self_type& set(element_type value, camp::idx_t i)
327 m_value = _mm512_mask_set1_epi32(m_value, 1 << i, value);
// Sets every lane to value.
334 self_type& broadcast(element_type
const& value)
336 m_value = _mm512_set1_epi32(value);
// Copies all lanes from src into this register.
343 self_type& copy(self_type
const& src)
345 m_value = src.m_value;
// Returns a new register holding the lane-wise sum this + b. The operands are
// not modified.
352 self_type add(self_type
const& b)
const
354 return self_type(_mm512_add_epi32(m_value, b.m_value));
// Returns a new register holding the lane-wise difference this - b. The
// operands are not modified.
360 self_type subtract(self_type
const& b)
const
362 return self_type(_mm512_sub_epi32(m_value, b.m_value));
// Returns a new register holding the lane-wise product this * b. Only the low
// 32 bits of each product are kept (_mm512_mullo_epi32).
368 self_type multiply(self_type
const& b)
const
370 return self_type(_mm512_mullo_epi32(m_value, b.m_value));
// Returns a new register holding the lane-wise quotient this / b.
// Each lane is divided with scalar int32_t division via get(), and the results
// are re-packed with _mm512_set_epi32, whose arguments run from lane 15 down to
// lane 0. Presumably this is done because AVX-512F has no packed 32-bit
// integer divide. The lane-0 term and closing lines are not shown.
// NOTE(review): scalar `/` is undefined behavior if any lane of b is 0, or for
// INT32_MIN / -1. All 16 lanes are evaluated, so callers must guarantee valid
// divisors in every lane (use divide_n for partial registers).
376 self_type divide(self_type
const& b)
const
379 return self_type(_mm512_set_epi32(
380 get(15) / b.get(15),
get(14) / b.get(14),
get(13) / b.get(13),
381 get(12) / b.get(12),
get(11) / b.get(11),
get(10) / b.get(10),
382 get(9) / b.get(9),
get(8) / b.get(8),
get(7) / b.get(7),
383 get(6) / b.get(6),
get(5) / b.get(5),
get(4) / b.get(4),
384 get(3) / b.get(3),
get(2) / b.get(2),
get(1) / b.get(1),
// Returns a new register where lane k (for k < N) holds get(k) / b.get(k) and
// every lane k >= N holds 0.
// Each ternary tests `N >= k+1` before dividing, so inactive lanes never
// evaluate the division. A zero in an inactive lane of b is therefore safe.
// Active lanes still use scalar int32_t `/`: a zero divisor, or INT32_MIN / -1,
// in an active lane is undefined behavior.
// _mm512_set_epi32 takes its arguments from lane 15 down to lane 0.
391 self_type divide_n(self_type
const& b, camp::idx_t N)
const
394 return self_type(_mm512_set_epi32(
395 N >= 16 ?
get(15) / b.get(15) : 0, N >= 15 ?
get(14) / b.get(14) : 0,
396 N >= 14 ?
get(13) / b.get(13) : 0, N >= 13 ?
get(12) / b.get(12) : 0,
397 N >= 12 ?
get(11) / b.get(11) : 0, N >= 11 ?
get(10) / b.get(10) : 0,
398 N >= 10 ?
get(9) / b.get(9) : 0, N >= 9 ?
get(8) / b.get(8) : 0,
399 N >= 8 ?
get(7) / b.get(7) : 0, N >= 7 ?
get(6) / b.get(6) : 0,
400 N >= 6 ?
get(5) / b.get(5) : 0, N >= 5 ?
get(4) / b.get(4) : 0,
401 N >= 4 ?
get(3) / b.get(3) : 0, N >= 3 ?
get(2) / b.get(2) : 0,
402 N >= 2 ?
get(1) / b.get(1) : 0, N >= 1 ?
get(0) / b.get(0) : 0));
// Returns the horizontal sum of all 16 lanes (_mm512_reduce_add_epi32).
410 element_type
sum()
const {
return _mm512_reduce_add_epi32(m_value); }
// Returns the largest value across all 16 lanes (_mm512_reduce_max_epi32).
417 element_type
max()
const {
return _mm512_reduce_max_epi32(m_value); }
// Returns the largest value among the lanes selected by createMask(N) only.
// NOTE(review): with N == 0 the mask is empty. Confirm that the intrinsic's
// empty-mask result (the reduction identity) is acceptable to callers.
424 element_type max_n(camp::idx_t N)
const
426 return _mm512_mask_reduce_max_epi32(createMask(N), m_value);
// Returns a new register holding the lane-wise maximum of this and a.
434 self_type vmax(self_type a)
const
436 return self_type(_mm512_max_epi32(m_value, a.m_value));
// Returns the smallest value across all 16 lanes (_mm512_reduce_min_epi32).
444 element_type
min()
const {
return _mm512_reduce_min_epi32(m_value); }
// Returns the smallest value among the lanes selected by createMask(N) only.
// NOTE(review): this partial reduction is an overload `min(N)`, while its
// counterpart above is named `max_n`. Confirm which name generic callers
// expect (e.g. `min_n`). With N == 0 the mask is empty; confirm the intrinsic's
// empty-mask result is acceptable.
451 element_type
min(camp::idx_t N)
const
453 return _mm512_mask_reduce_min_epi32(createMask(N), m_value);
// Returns a new register holding the lane-wise minimum of this and a.
461 self_type vmin(self_type a)
const
463 return self_type(_mm512_min_epi32(m_value, a.m_value));
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155