RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
avx512_float.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifdef __AVX512F__
21 
22 #ifndef RAJA_policy_vector_register_avx512_float_HPP
23 #define RAJA_policy_vector_register_avx512_float_HPP
24 
25 #include "RAJA/config.hpp"
26 #include "RAJA/util/macros.hpp"
28 
29 // Include SIMD intrinsics header file
30 #include <immintrin.h>
31 #include <cmath>
32 
33 namespace RAJA
34 {
35 namespace expt
36 {
37 template<>
38 class Register<float, avx512_register>
39  : public internal::expt::RegisterBase<Register<float, avx512_register>>
40 {
41 public:
42  using base_type =
43  internal::expt::RegisterBase<Register<float, avx512_register>>;
44 
45  using register_policy = avx512_register;
46  using self_type = Register<float, avx512_register>;
47  using element_type = float;
48  using register_type = __m512;
49 
50  using int_vector_type = Register<int32_t, avx512_register>;
51 
52 
53 private:
54  register_type m_value;
55 
56  RAJA_INLINE
57  __mmask16 createMask(camp::idx_t N) const
58  {
59  // Generate a mask
60  switch (N)
61  {
62  case 0:
63  return __mmask16(0x0000);
64  case 1:
65  return __mmask16(0x0001);
66  case 2:
67  return __mmask16(0x0003);
68  case 3:
69  return __mmask16(0x0007);
70  case 4:
71  return __mmask16(0x000F);
72  case 5:
73  return __mmask16(0x001F);
74  case 6:
75  return __mmask16(0x003F);
76  case 7:
77  return __mmask16(0x007F);
78  case 8:
79  return __mmask16(0x00FF);
80  case 9:
81  return __mmask16(0x01FF);
82  case 10:
83  return __mmask16(0x03FF);
84  case 11:
85  return __mmask16(0x07FF);
86  case 12:
87  return __mmask16(0x0FFF);
88  case 13:
89  return __mmask16(0x1FFF);
90  case 14:
91  return __mmask16(0x3FFF);
92  case 15:
93  return __mmask16(0x7FFF);
94  case 16:
95  return __mmask16(0xFFFF);
96  }
97  return __mmask16(0);
98  }
99 
100  RAJA_INLINE
101  __m512i createStridedOffsets(camp::idx_t stride) const
102  {
103  // Generate a strided offset list
104  auto vstride = _mm512_set1_epi32(stride);
105  auto vseq =
106  _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
107  return _mm512_mullo_epi32(vstride, vseq);
108  }
109 
110 public:
111  static constexpr camp::idx_t s_num_elem = 16;
112 
116  // AVX512F
117  RAJA_INLINE
118  Register() : base_type(), m_value(_mm512_setzero_ps()) {}
119 
123  RAJA_INLINE
124  explicit Register(register_type const& c) : base_type(), m_value(c) {}
125 
129  RAJA_INLINE
130  Register(self_type const& c) : base_type(), m_value(c.m_value) {}
131 
135  RAJA_INLINE
136  self_type& operator=(self_type const& c)
137  {
138  m_value = c.m_value;
139  return *this;
140  }
141 
146  // AVX512F
147  RAJA_INLINE
148  Register(element_type const& c) : base_type(), m_value(_mm512_set1_ps(c)) {}
149 
154  RAJA_INLINE
155  self_type& load_packed(element_type const* ptr)
156  {
157  // AVX512F
158  m_value = _mm512_loadu_ps(ptr);
159  return *this;
160  }
161 
167  RAJA_INLINE
168  self_type& load_packed_n(element_type const* ptr, camp::idx_t N)
169  {
170  // AVX512F
171  m_value = _mm512_mask_loadu_ps(_mm512_setzero_ps(), createMask(N), ptr);
172  return *this;
173  }
174 
179  RAJA_INLINE
180  self_type& load_strided(element_type const* ptr, camp::idx_t stride)
181  {
182  // AVX512F
183  m_value = _mm512_i32gather_ps(createStridedOffsets(stride), ptr,
184  sizeof(element_type));
185  return *this;
186  }
187 
193  RAJA_INLINE
194  self_type& load_strided_n(element_type const* ptr,
195  camp::idx_t stride,
196  camp::idx_t N)
197  {
198  // AVX512F
199  m_value = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), createMask(N),
200  createStridedOffsets(stride), ptr,
201  sizeof(element_type));
202  return *this;
203  }
204 
209  RAJA_INLINE
210  self_type const& store_packed(element_type* ptr) const
211  {
212  // AVX512F
213  _mm512_storeu_ps(ptr, m_value);
214  return *this;
215  }
216 
221  RAJA_INLINE
222  self_type const& store_packed_n(element_type* ptr, camp::idx_t N) const
223  {
224  // AVX512F
225  _mm512_mask_storeu_ps(ptr, createMask(N), m_value);
226  return *this;
227  }
228 
233  RAJA_INLINE
234  self_type const& store_strided(element_type* ptr, camp::idx_t stride) const
235  {
236  // AVX512F
237  _mm512_i32scatter_ps(ptr, createStridedOffsets(stride), m_value,
238  sizeof(element_type));
239  return *this;
240  }
241 
246  RAJA_INLINE
247  self_type const& store_strided_n(element_type* ptr,
248  camp::idx_t stride,
249  camp::idx_t N) const
250  {
251  // AVX512F
252  _mm512_mask_i32scatter_ps(ptr, createMask(N), createStridedOffsets(stride),
253  m_value, sizeof(element_type));
254  return *this;
255  }
256 
262  RAJA_INLINE
263  element_type get(camp::idx_t i) const { return m_value[i]; }
264 
270  RAJA_INLINE
271  self_type& set(element_type value, camp::idx_t i)
272  {
273  m_value[i] = value;
274  return *this;
275  }
276 
278 
279  RAJA_INLINE
280  self_type& broadcast(element_type const& value)
281  {
282  m_value = _mm512_set1_ps(value);
283  return *this;
284  }
285 
287 
288  RAJA_INLINE
289  self_type& copy(self_type const& src)
290  {
291  m_value = src.m_value;
292  return *this;
293  }
294 
296 
297  RAJA_INLINE
298  self_type add(self_type const& b) const
299  {
300  return self_type(_mm512_add_ps(m_value, b.m_value));
301  }
302 
304 
305  RAJA_INLINE
306  self_type subtract(self_type const& b) const
307  {
308  return self_type(_mm512_sub_ps(m_value, b.m_value));
309  }
310 
312 
313  RAJA_INLINE
314  self_type multiply(self_type const& b) const
315  {
316  return self_type(_mm512_mul_ps(m_value, b.m_value));
317  }
318 
320 
321  RAJA_INLINE
322  self_type divide(self_type const& b) const
323  {
324  return self_type(_mm512_div_ps(m_value, b.m_value));
325  }
326 
328 
329  RAJA_INLINE
330  self_type divide_n(self_type const& b, camp::idx_t N) const
331  {
332  return self_type(_mm512_maskz_div_ps(createMask(N), m_value, b.m_value));
333  }
334 
335 // only use FMA's if the compiler has them turned on
336 #ifdef __FMA__
337  RAJA_INLINE
338 
340  self_type multiply_add(self_type const& b, self_type const& c) const
341  {
342  return self_type(_mm512_fmadd_ps(m_value, b.m_value, c.m_value));
343  }
344 
345  RAJA_INLINE
346 
348  self_type multiply_subtract(self_type const& b, self_type const& c) const
349  {
350  return self_type(_mm512_fmsub_ps(m_value, b.m_value, c.m_value));
351  }
352 #endif
353 
358  RAJA_INLINE
359  element_type sum() const { return _mm512_reduce_add_ps(m_value); }
360 
365  RAJA_INLINE
366  element_type max() const { return _mm512_reduce_max_ps(m_value); }
367 
372  RAJA_INLINE
373  element_type max_n(camp::idx_t N) const
374  {
375  return _mm512_mask_reduce_max_ps(createMask(N), m_value);
376  }
377 
382  RAJA_INLINE
383  self_type vmax(self_type a) const
384  {
385  return self_type(_mm512_max_ps(m_value, a.m_value));
386  }
387 
392  RAJA_INLINE
393  element_type min() const { return _mm512_reduce_min_ps(m_value); }
394 
399  RAJA_INLINE
400  element_type min_n(camp::idx_t N) const
401  {
402  return _mm512_mask_reduce_min_ps(createMask(N), m_value);
403  }
404 
409  RAJA_INLINE
410  self_type vmin(self_type a) const
411  {
412  return self_type(_mm512_min_ps(m_value, a.m_value));
413  }
414 };
415 
416 
417 } // namespace expt
418 
419 } // namespace RAJA
420 
421 
422 #endif
423 
424 #endif //__AVX512F__
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155