RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
avx2_float.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifdef __AVX2__
21 
22 #ifndef RAJA_policy_vector_register_avx2_float_HPP
23 #define RAJA_policy_vector_register_avx2_float_HPP
24 
25 #include "RAJA/config.hpp"
26 #include "RAJA/util/macros.hpp"
28 
29 // Include SIMD intrinsics header file
30 #include <immintrin.h>
31 #include <cmath>
32 
33 namespace RAJA
34 {
35 namespace expt
36 {
37 template<>
38 class Register<float, avx2_register>
39  : public internal::expt::RegisterBase<Register<float, avx2_register>>
40 {
41 public:
42  using base_type =
43  internal::expt::RegisterBase<Register<float, avx2_register>>;
44 
45  using register_policy = avx2_register;
46  using self_type = Register<float, avx2_register>;
47  using element_type = float;
48  using register_type = __m256;
49 
50  using int_vector_type = Register<int32_t, avx2_register>;
51 
52 
53 private:
54  register_type m_value;
55 
56  RAJA_INLINE
57  __m256i createMask(camp::idx_t N) const
58  {
59  // Generate a mask
60  return _mm256_set_epi32(N >= 8 ? -1 : 0, N >= 7 ? -1 : 0, N >= 6 ? -1 : 0,
61  N >= 5 ? -1 : 0, N >= 4 ? -1 : 0, N >= 3 ? -1 : 0,
62  N >= 2 ? -1 : 0, N >= 1 ? -1 : 0);
63  }
64 
65  RAJA_INLINE
66  __m256i createStridedOffsets(camp::idx_t stride) const
67  {
68  // Generate a strided offset list
69  return _mm256_set_epi32(7 * stride, 6 * stride, 5 * stride, 4 * stride,
70  3 * stride, 2 * stride, stride, 0);
71  }
72 
73  RAJA_INLINE
74  __m256i createPermute1(camp::idx_t N) const
75  {
76  // Generate a permutation for first round of min/max routines
77  return _mm256_set_epi32(N >= 7 ? 6 : 0, N >= 8 ? 7 : 0, N >= 5 ? 4 : 0,
78  N >= 6 ? 5 : 0, N >= 3 ? 2 : 0, N >= 4 ? 3 : 0,
79  N >= 1 ? 0 : 0, N >= 2 ? 1 : 0);
80  }
81 
82  RAJA_INLINE
83  __m256i createPermute2(camp::idx_t N) const
84  {
85  // Generate a permutation for second round of min/max routines
86  return _mm256_set_epi32(N >= 6 ? 5 : 0, N >= 5 ? 4 : 0, N >= 8 ? 7 : 0,
87  N >= 7 ? 6 : 0, N >= 2 ? 1 : 0, N >= 1 ? 0 : 0,
88  N >= 4 ? 3 : 0, N >= 2 ? 2 : 0);
89  }
90 
91 public:
92  static constexpr camp::idx_t s_num_elem = 8;
93 
97  RAJA_INLINE
98  Register() : m_value(_mm256_setzero_ps()) {}
99 
103  RAJA_INLINE
104  Register(element_type x0,
105  element_type x1,
106  element_type x2,
107  element_type x3,
108  element_type x4,
109  element_type x5,
110  element_type x6,
111  element_type x7)
112  : m_value(_mm256_set_ps(x7, x6, x5, x4, x3, x2, x1, x0))
113  {}
114 
118  RAJA_INLINE
119  explicit Register(register_type const& c) : m_value(c) {}
120 
124  RAJA_INLINE
125  Register(self_type const& c) : base_type(c), m_value(c.m_value) {}
126 
130  RAJA_INLINE
131  self_type& operator=(self_type const& c)
132  {
133  m_value = c.m_value;
134  return *this;
135  }
136 
141  RAJA_INLINE
142  Register(element_type const& c) : m_value(_mm256_set1_ps(c)) {}
143 
147  RAJA_INLINE
148  constexpr register_type get_register() const { return m_value; }
149 
154  RAJA_INLINE
155  self_type& load_packed(element_type const* ptr)
156  {
157  m_value = _mm256_loadu_ps(ptr);
158  return *this;
159  }
160 
166  RAJA_INLINE
167  self_type& load_packed_n(element_type const* ptr, camp::idx_t N)
168  {
169  m_value = _mm256_maskload_ps(ptr, createMask(N));
170  return *this;
171  }
172 
177  RAJA_INLINE
178  self_type& load_strided(element_type const* ptr, camp::idx_t stride)
179  {
180  m_value = _mm256_i32gather_ps(ptr, createStridedOffsets(stride),
181  sizeof(element_type));
182  return *this;
183  }
184 
190  RAJA_INLINE
191  self_type& load_strided_n(element_type const* ptr,
192  camp::idx_t stride,
193  camp::idx_t N)
194  {
195  m_value = _mm256_mask_i32gather_ps(
196  _mm256_setzero_ps(), ptr, createStridedOffsets(stride),
197  _mm256_castsi256_ps(createMask(N)), sizeof(element_type));
198  return *this;
199  }
200 
205  RAJA_INLINE
206  self_type const& store_packed(element_type* ptr) const
207  {
208  _mm256_storeu_ps(ptr, m_value);
209  return *this;
210  }
211 
216  RAJA_INLINE
217  self_type const& store_packed_n(element_type* ptr, camp::idx_t N) const
218  {
219  _mm256_maskstore_ps(ptr, createMask(N), m_value);
220  return *this;
221  }
222 
227  RAJA_INLINE
228  self_type const& store_strided(element_type* ptr, camp::idx_t stride) const
229  {
230  for (camp::idx_t i = 0; i < 8; ++i)
231  {
232  ptr[i * stride] = m_value[i];
233  }
234  return *this;
235  }
236 
241  RAJA_INLINE
242  self_type const& store_strided_n(element_type* ptr,
243  camp::idx_t stride,
244  camp::idx_t N) const
245  {
246  for (camp::idx_t i = 0; i < N; ++i)
247  {
248  ptr[i * stride] = m_value[i];
249  }
250  return *this;
251  }
252 
258  RAJA_INLINE
259  element_type get(camp::idx_t i) const { return m_value[i]; }
260 
266  RAJA_INLINE
267  self_type& set(element_type value, camp::idx_t i)
268  {
269  m_value[i] = value;
270  return *this;
271  }
272 
274 
275  RAJA_INLINE
276  self_type& broadcast(element_type const& value)
277  {
278  m_value = _mm256_set1_ps(value);
279  return *this;
280  }
281 
283 
284  RAJA_INLINE
285  self_type& copy(self_type const& src)
286  {
287  m_value = src.m_value;
288  return *this;
289  }
290 
292 
293  RAJA_INLINE
294  self_type add(self_type const& b) const
295  {
296  return self_type(_mm256_add_ps(m_value, b.m_value));
297  }
298 
300 
301  RAJA_INLINE
302  self_type subtract(self_type const& b) const
303  {
304  return self_type(_mm256_sub_ps(m_value, b.m_value));
305  }
306 
308 
309  RAJA_INLINE
310  self_type multiply(self_type const& b) const
311  {
312  return self_type(_mm256_mul_ps(m_value, b.m_value));
313  }
314 
316 
317  RAJA_INLINE
318  self_type divide(self_type const& b) const
319  {
320  return self_type(_mm256_div_ps(m_value, b.m_value));
321  }
322 
324 
325  RAJA_INLINE
326  self_type divide_n(self_type const& b, camp::idx_t N) const
327  {
328  // AVX2 does not supply a masked divide
329  return self_type(_mm256_set_ps(
330  N >= 8 ? get(7) / b.get(7) : 0, N >= 7 ? get(6) / b.get(6) : 0,
331  N >= 6 ? get(5) / b.get(5) : 0, N >= 5 ? get(4) / b.get(4) : 0,
332  N >= 4 ? get(3) / b.get(3) : 0, N >= 3 ? get(2) / b.get(2) : 0,
333  N >= 2 ? get(1) / b.get(1) : 0, N >= 1 ? get(0) / b.get(0) : 0));
334  }
335 
336 // only use FMA's if the compiler has them turned on
337 #ifdef __FMA__
338  RAJA_INLINE
339 
341  self_type multiply_add(self_type const& b, self_type const& c) const
342  {
343  return self_type(_mm256_fmadd_ps(m_value, b.m_value, c.m_value));
344  }
345 
346  RAJA_INLINE
347 
349  self_type multiply_subtract(self_type const& b, self_type const& c) const
350  {
351  return self_type(_mm256_fmsub_ps(m_value, b.m_value, c.m_value));
352  }
353 #endif
354 
359  RAJA_INLINE
360  element_type sum() const
361  {
362  // swap odd-even pairs and add
363  auto sh1 = _mm256_permute_ps(m_value, 0xB1);
364  auto red1 = _mm256_add_ps(m_value, sh1);
365 
366  // swap odd-even quads and add
367  auto sh2 = _mm256_permute_ps(red1, 0x4E);
368  auto red2 = _mm256_add_ps(red1, sh2);
369 
370  return red2[0] + red2[4];
371  }
372 
377  RAJA_INLINE
378  element_type max() const
379  {
380 
381  // swap odd-even pairs and add
382  auto sh1 = _mm256_permutevar8x32_ps(m_value, createPermute1(8));
383  auto red1 = _mm256_max_ps(m_value, sh1);
384 
385  // swap odd-even quads and add
386  auto sh2 = _mm256_permutevar8x32_ps(red1, createPermute2(8));
387  auto red2 = _mm256_max_ps(red1, sh2);
388 
389  return std::max<element_type>(red2[0], red2[4]);
390  }
391 
396  RAJA_INLINE
397  element_type max_n(camp::idx_t N) const
398  {
399  // Some simple cases
400  if (N <= 0 || N > 8)
401  {
403  }
404  if (N == 1)
405  {
406  return m_value[0];
407  }
408  if (N == 2)
409  {
410  return std::max<element_type>(m_value[0], m_value[1]);
411  }
412 
413  // swap odd-even pairs and add
414  auto sh1 = _mm256_permutevar8x32_ps(m_value, createPermute1(N));
415  auto red1 = _mm256_max_ps(m_value, sh1);
416 
417  if (N == 3)
418  {
419  return std::max<element_type>(red1[0], m_value[2]);
420  }
421  if (N == 4)
422  {
423  return std::max<element_type>(red1[0], red1[2]);
424  }
425 
426  // swap odd-even quads and add
427  auto sh2 = _mm256_permutevar8x32_ps(red1, createPermute2(N));
428  auto red2 = _mm256_max_ps(red1, sh2);
429 
430  return std::max<element_type>(red2[0], red2[4]);
431  }
432 
437  RAJA_INLINE
438  self_type vmax(self_type a) const
439  {
440  return self_type(_mm256_max_ps(m_value, a.m_value));
441  }
442 
447  RAJA_INLINE
448  element_type min() const
449  {
450 
451  // swap odd-even pairs and add
452  auto sh1 = _mm256_permutevar8x32_ps(m_value, createPermute1(8));
453  auto red1 = _mm256_min_ps(m_value, sh1);
454 
455  // swap odd-even quads and add
456  auto sh2 = _mm256_permutevar8x32_ps(red1, createPermute2(8));
457  auto red2 = _mm256_min_ps(red1, sh2);
458 
459  return std::min<element_type>(red2[0], red2[4]);
460  }
461 
466  RAJA_INLINE
467  element_type min_n(camp::idx_t N) const
468  {
469  // Some simple cases
470  if (N <= 0 || N > 8)
471  {
473  }
474  if (N == 1)
475  {
476  return m_value[0];
477  }
478  if (N == 2)
479  {
480  return std::min<element_type>(m_value[0], m_value[1]);
481  }
482 
483  // swap odd-even pairs and add
484  auto sh1 = _mm256_permutevar8x32_ps(m_value, createPermute1(N));
485  auto red1 = _mm256_min_ps(m_value, sh1);
486 
487  if (N == 3)
488  {
489  return std::min<element_type>(red1[0], m_value[2]);
490  }
491  if (N == 4)
492  {
493  return std::min<element_type>(red1[0], red1[2]);
494  }
495 
496  // swap odd-even quads and add
497  auto sh2 = _mm256_permutevar8x32_ps(red1, createPermute2(N));
498  auto red2 = _mm256_min_ps(red1, sh2);
499 
500  return std::min<element_type>(red2[0], red2[4]);
501  }
502 
507  RAJA_INLINE
508  self_type vmin(self_type a) const
509  {
510  return self_type(_mm256_min_ps(m_value, a.m_value));
511  }
512 };
513 
514 
515 } // namespace expt
516 
517 } // namespace RAJA
518 
519 
520 #endif
521 
522 #endif //__AVX2__
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155