22 #ifndef RAJA_policy_vector_register_avx512_float_HPP
23 #define RAJA_policy_vector_register_avx512_float_HPP
25 #include "RAJA/config.hpp"
30 #include <immintrin.h>
// Register specialization for element type float with the avx512_register
// policy. The underlying storage is a single __m512 value (m_value).
// NOTE(review): this listing omits several source lines (e.g. braces and
// access specifiers are not visible); comments describe only what is shown.
38 class Register<float, avx512_register>
39 :
public internal::expt::RegisterBase<Register<float, avx512_register>>
// Tail of a declaration whose beginning is not visible here; presumably the
// `base_type` alias that the constructors below initialize -- TODO confirm.
43 internal::expt::RegisterBase<Register<float, avx512_register>>;
// Policy tag identifying this register implementation.
45 using register_policy = avx512_register;
// Shorthand for this specialization's own type.
46 using self_type = Register<float, avx512_register>;
// Scalar type stored in each lane.
47 using element_type = float;
// Native intrinsic vector type wrapped by this class.
48 using register_type = __m512;
// Companion register type holding 32-bit integers under the same policy.
50 using int_vector_type = Register<int32_t, avx512_register>;
// The wrapped 512-bit vector value; every operation below reads or writes it.
54 register_type m_value;
/// Returns a 16-bit AVX-512 lane mask (__mmask16) derived from N.
/// The visible return values are 0x0000, 0x0001, 0x0003, ... 0xFFFF, i.e.
/// masks with the lowest 0, 1, 2, ... 16 bits set.
// NOTE(review): the control flow that selects among these returns is not
// visible in this listing; presumably N == k yields the mask with the k
// lowest bits set, with 0xFFFF as the final fallback -- confirm against the
// full source.
57 __mmask16 createMask(camp::idx_t N)
const
63 return __mmask16(0x0000);
65 return __mmask16(0x0001);
67 return __mmask16(0x0003);
69 return __mmask16(0x0007);
71 return __mmask16(0x000F);
73 return __mmask16(0x001F);
75 return __mmask16(0x003F);
77 return __mmask16(0x007F);
79 return __mmask16(0x00FF);
81 return __mmask16(0x01FF);
83 return __mmask16(0x03FF);
85 return __mmask16(0x07FF);
87 return __mmask16(0x0FFF);
89 return __mmask16(0x1FFF);
91 return __mmask16(0x3FFF);
93 return __mmask16(0x7FFF);
// All 16 bits set.
95 return __mmask16(0xFFFF);
/// Builds a vector of sixteen 32-bit offsets by multiplying a broadcast of
/// `stride` with the sequence 0..15, giving {0, stride, ..., 15*stride}.
/// \param stride distance between consecutive elements.
/// \return the per-lane offsets as an __m512i.
101 __m512i createStridedOffsets(camp::idx_t stride)
const
// Broadcast stride to all 16 lanes. _mm512_set1_epi32 takes an int, so the
// value is converted to 32 bits here.
// NOTE(review): if camp::idx_t is wider than int, large strides are narrowed.
104 auto vstride = _mm512_set1_epi32(stride);
// Arguments are listed from the highest lane to the lowest, so lane i holds i.
// The declaration that names this value `vseq` is not visible in this listing.
106 _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
// Lane-wise product, keeping the low 32 bits of each result.
107 return _mm512_mullo_epi32(vstride, vseq);
/// Number of float lanes in this register type (compile-time constant 16).
111 static constexpr camp::idx_t s_num_elem = 16;
/// Default constructor: initializes the register from _mm512_setzero_ps(),
/// i.e. all lanes are zero.
118 Register() : base_type(), m_value(_mm512_setzero_ps()) {}
/// Wraps an existing native __m512 value. Declared explicit, so a raw
/// register_type is never converted to a Register implicitly.
/// \param c native vector value to copy into m_value.
124 explicit Register(register_type
const& c) : base_type(), m_value(c) {}
/// Copy constructor: copies the wrapped vector value from another register.
/// \param c register to copy from.
130 Register(self_type
const& c) : base_type(), m_value(c.m_value) {}
/// Copy assignment from another register of the same type.
// NOTE(review): only the signature is visible in this listing; the body is
// not shown, so its exact behavior cannot be documented from here.
136 self_type& operator=(self_type
const& c)
/// Constructs a register with every lane set to the scalar `c`
/// (via _mm512_set1_ps).
/// \param c value replicated into all lanes.
// No `explicit` keyword is visible on this line, which would allow implicit
// conversion from a scalar float -- TODO confirm against the full source.
148 Register(element_type
const& c) : base_type(), m_value(_mm512_set1_ps(c)) {}
/// Loads a full register of contiguous floats starting at `ptr`.
/// \param ptr address of the first element to read.
// Uses the unaligned load intrinsic, so ptr need not be 64-byte aligned.
// The return statement is not visible in this listing.
155 self_type& load_packed(element_type
const* ptr)
158 m_value = _mm512_loadu_ps(ptr);
/// Loads contiguous floats from `ptr` into the lanes selected by
/// createMask(N).
/// \param ptr address of the first element to read.
/// \param N   value passed to createMask to choose the active lanes.
// Lanes not selected by the mask take their value from the first argument,
// which is an all-zero vector here, so they end up as 0.
// The return statement is not visible in this listing.
168 self_type& load_packed_n(element_type
const* ptr, camp::idx_t N)
171 m_value = _mm512_mask_loadu_ps(_mm512_setzero_ps(), createMask(N), ptr);
/// Gathers a full register of floats from `ptr` using the offsets produced by
/// createStridedOffsets(stride).
/// \param ptr    base address of the gather.
/// \param stride distance between consecutive elements.
// The gather scale is sizeof(element_type), so the offsets are expressed in
// elements rather than bytes.
// The return statement is not visible in this listing.
180 self_type& load_strided(element_type
const* ptr, camp::idx_t stride)
183 m_value = _mm512_i32gather_ps(createStridedOffsets(stride), ptr,
184 sizeof(element_type));
/// Masked strided gather: loads floats from `ptr` at the offsets given by
/// createStridedOffsets(stride), only for the lanes selected by createMask(N).
// NOTE(review): the rest of the parameter list is not visible in this
// listing; the body uses `stride` and `N`, which are presumably declared
// there -- confirm.
// Lanes not selected by the mask take their value from the all-zero source
// vector passed as the first argument.
194 self_type& load_strided_n(element_type
const* ptr,
199 m_value = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), createMask(N),
200 createStridedOffsets(stride), ptr,
// Scale of sizeof(element_type) makes the offsets element-based.
201 sizeof(element_type));
/// Stores the whole register to contiguous memory starting at `ptr`.
/// \param ptr destination address for the first element.
// Uses the unaligned store intrinsic, so ptr need not be 64-byte aligned.
// The return statement is not visible in this listing.
210 self_type
const& store_packed(element_type* ptr)
const
213 _mm512_storeu_ps(ptr, m_value);
/// Stores to contiguous memory at `ptr` only the lanes selected by
/// createMask(N).
/// \param ptr destination address for the first element.
/// \param N   value passed to createMask to choose the lanes written.
// Memory corresponding to unselected lanes is not written by the masked store.
// The return statement is not visible in this listing.
222 self_type
const& store_packed_n(element_type* ptr, camp::idx_t N)
const
225 _mm512_mask_storeu_ps(ptr, createMask(N), m_value);
/// Scatters all lanes of the register to memory at `ptr`, using the offsets
/// produced by createStridedOffsets(stride).
/// \param ptr    base address of the scatter.
/// \param stride distance between consecutive elements.
// The scatter scale is sizeof(element_type), so the offsets are expressed in
// elements rather than bytes.
// The return statement is not visible in this listing.
234 self_type
const& store_strided(element_type* ptr, camp::idx_t stride)
const
237 _mm512_i32scatter_ps(ptr, createStridedOffsets(stride), m_value,
238 sizeof(element_type));
/// Masked strided scatter: writes to `ptr`, at the offsets given by
/// createStridedOffsets(stride), only the lanes selected by createMask(N).
// NOTE(review): the rest of the parameter list is not visible in this
// listing; the body uses `stride` and `N`, which are presumably declared
// there -- confirm.
247 self_type
const& store_strided_n(element_type* ptr,
252 _mm512_mask_i32scatter_ps(ptr, createMask(N), createStridedOffsets(stride),
253 m_value,
// Scale of sizeof(element_type) makes the offsets element-based.
sizeof(element_type));
/// Returns the value stored in lane `i` of the register.
/// \param i lane index; no bounds check is performed in the visible code.
// NOTE(review): subscripting an __m512 directly relies on the compiler
// treating it as an indexable vector type -- confirm the supported toolchains.
263 element_type
get(camp::idx_t i)
const {
return m_value[i]; }
/// Sets a single lane of the register; takes the new value and a lane index.
// NOTE(review): only the signature is visible in this listing; the body is
// not shown, so its exact behavior cannot be documented from here.
271 self_type& set(element_type value, camp::idx_t i)
/// Overwrites every lane of the register with the scalar `value`.
/// \param value scalar replicated into all lanes via _mm512_set1_ps.
// The return statement is not visible in this listing.
280 self_type& broadcast(element_type
const& value)
282 m_value = _mm512_set1_ps(value);
/// Copies the wrapped vector value of `src` into this register.
/// \param src register to copy from.
// The return statement is not visible in this listing.
289 self_type& copy(self_type
const& src)
291 m_value = src.m_value;
/// Lane-wise addition: returns a new register holding (*this) + b.
/// \param b right-hand operand.
/// \return the per-lane sums; neither operand is modified.
298 self_type add(self_type
const& b)
const
300 return self_type(_mm512_add_ps(m_value, b.m_value));
/// Lane-wise subtraction: returns a new register holding (*this) - b.
/// \param b right-hand operand (the subtrahend).
/// \return the per-lane differences; neither operand is modified.
306 self_type subtract(self_type
const& b)
const
308 return self_type(_mm512_sub_ps(m_value, b.m_value));
/// Lane-wise multiplication: returns a new register holding (*this) * b.
/// \param b right-hand operand.
/// \return the per-lane products; neither operand is modified.
314 self_type multiply(self_type
const& b)
const
316 return self_type(_mm512_mul_ps(m_value, b.m_value));
/// Lane-wise division: returns a new register holding (*this) / b.
/// \param b right-hand operand (the divisor).
/// \return the per-lane quotients; all 16 lanes are divided.
322 self_type divide(self_type
const& b)
const
324 return self_type(_mm512_div_ps(m_value, b.m_value));
/// Lane-wise division restricted to the lanes selected by createMask(N).
/// \param b right-hand operand (the divisor).
/// \param N value passed to createMask to choose the active lanes.
/// \return per-lane quotients in the selected lanes.
// The zero-masking ("maskz") form is used, so unselected lanes of the result
// are set to zero rather than being divided.
330 self_type divide_n(self_type
const& b, camp::idx_t N)
const
332 return self_type(_mm512_maskz_div_ps(createMask(N), m_value, b.m_value));
/// Fused multiply-add: returns a new register holding (*this) * b + c,
/// computed per lane with _mm512_fmadd_ps.
/// \param b multiplier.
/// \param c addend.
340 self_type multiply_add(self_type
const& b, self_type
const& c)
const
342 return self_type(_mm512_fmadd_ps(m_value, b.m_value, c.m_value));
/// Fused multiply-subtract: returns a new register holding (*this) * b - c,
/// computed per lane with _mm512_fmsub_ps.
/// \param b multiplier.
/// \param c value subtracted from the product.
348 self_type multiply_subtract(self_type
const& b, self_type
const& c)
const
350 return self_type(_mm512_fmsub_ps(m_value, b.m_value, c.m_value));
/// Horizontal reduction: returns the sum of all 16 lanes as a scalar.
359 element_type
sum()
const {
return _mm512_reduce_add_ps(m_value); }
/// Horizontal reduction: returns the largest value among all 16 lanes.
366 element_type
max()
const {
return _mm512_reduce_max_ps(m_value); }
/// Horizontal reduction: returns the largest value among the lanes selected
/// by createMask(N).
/// \param N value passed to createMask to choose the lanes considered.
373 element_type max_n(camp::idx_t N)
const
375 return _mm512_mask_reduce_max_ps(createMask(N), m_value);
/// Lane-wise maximum: returns a new register whose lane i is the larger of
/// lane i of *this and lane i of `a`.
/// \param a register to compare against (taken by value).
383 self_type vmax(self_type a)
const
385 return self_type(_mm512_max_ps(m_value, a.m_value));
/// Horizontal reduction: returns the smallest value among all 16 lanes.
393 element_type
min()
const {
return _mm512_reduce_min_ps(m_value); }
/// Horizontal reduction: returns the smallest value among the lanes selected
/// by createMask(N).
/// \param N value passed to createMask to choose the lanes considered.
400 element_type min_n(camp::idx_t N)
const
402 return _mm512_mask_reduce_min_ps(createMask(N), m_value);
/// Lane-wise minimum: returns a new register whose lane i is the smaller of
/// lane i of *this and lane i of `a`.
/// \param a register to compare against (taken by value).
410 self_type vmin(self_type a)
const
412 return self_type(_mm512_min_ps(m_value, a.m_value));
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155