RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
avx_float.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifdef __AVX__
21 
22 #ifndef RAJA_policy_vector_register_avx_float_HPP
23 #define RAJA_policy_vector_register_avx_float_HPP
24 
25 #include "RAJA/config.hpp"
26 #include "RAJA/util/macros.hpp"
28 
29 // Include SIMD intrinsics header file
30 #include <immintrin.h>
31 #include <cmath>
32 
33 namespace RAJA
34 {
35 namespace expt
36 {
37 
38 template<>
39 class Register<float, avx_register>
40  : public internal::expt::RegisterBase<Register<float, avx_register>>
41 {
42 public:
43  using base_type = internal::expt::RegisterBase<Register<float, avx_register>>;
44 
45  using register_policy = avx_register;
46  using self_type = Register<float, avx_register>;
47  using element_type = float;
48  using register_type = __m256;
49 
50  using int_vector_type = Register<int32_t, avx_register>;
51 
52 
53 private:
54  register_type m_value;
55 
56  RAJA_INLINE
57  __m256i createMask(camp::idx_t N) const
58  {
59  // Generate a mask
60  return _mm256_set_epi32(N >= 8 ? -1 : 0, N >= 7 ? -1 : 0, N >= 6 ? -1 : 0,
61  N >= 5 ? -1 : 0, N >= 4 ? -1 : 0, N >= 3 ? -1 : 0,
62  N >= 2 ? -1 : 0, N >= 1 ? -1 : 0);
63  }
64 
65 public:
66  static constexpr camp::idx_t s_num_elem = 8;
67 
71  RAJA_INLINE
72  Register() : base_type(), m_value(_mm256_setzero_ps()) {}
73 
77  RAJA_INLINE
78  explicit Register(register_type const& c) : base_type(), m_value(c) {}
79 
83  RAJA_INLINE
84  Register(element_type x0,
85  element_type x1,
86  element_type x2,
87  element_type x3,
88  element_type x4,
89  element_type x5,
90  element_type x6,
91  element_type x7)
92  : m_value(_mm256_set_ps(x7, x6, x5, x4, x3, x2, x1, x0))
93  {}
94 
98  RAJA_INLINE
99  Register(self_type const& c) : base_type(), m_value(c.m_value) {}
100 
104  RAJA_INLINE
105  self_type& operator=(self_type const& c)
106  {
107  m_value = c.m_value;
108  return *this;
109  }
110 
115  RAJA_INLINE
116  Register(element_type const& c) : m_value(_mm256_set1_ps(c)) {}
117 
122  RAJA_INLINE
123  self_type& load_packed(element_type const* ptr)
124  {
125  m_value = _mm256_loadu_ps(ptr);
126  return *this;
127  }
128 
134  RAJA_INLINE
135  self_type& load_packed_n(element_type const* ptr, camp::idx_t N)
136  {
137  m_value = _mm256_maskload_ps(ptr, createMask(N));
138  return *this;
139  }
140 
145  RAJA_INLINE
146  self_type& load_strided(element_type const* ptr, camp::idx_t stride)
147  {
148  for (camp::idx_t i = 0; i < 8; ++i)
149  {
150  m_value[i] = ptr[i * stride];
151  }
152  return *this;
153  }
154 
160  RAJA_INLINE
161  self_type& load_strided_n(element_type const* ptr,
162  camp::idx_t stride,
163  camp::idx_t N)
164  {
165  m_value = _mm256_setzero_ps();
166  for (camp::idx_t i = 0; i < N; ++i)
167  {
168  m_value[i] = ptr[i * stride];
169  }
170  return *this;
171  }
172 
177  RAJA_INLINE
178  self_type const& store_packed(element_type* ptr) const
179  {
180  _mm256_storeu_ps(ptr, m_value);
181  return *this;
182  }
183 
188  RAJA_INLINE
189  self_type const& store_packed_n(element_type* ptr, camp::idx_t N) const
190  {
191  _mm256_maskstore_ps(ptr, createMask(N), m_value);
192  return *this;
193  }
194 
199  RAJA_INLINE
200  self_type const& store_strided(element_type* ptr, camp::idx_t stride) const
201  {
202  for (camp::idx_t i = 0; i < 8; ++i)
203  {
204  ptr[i * stride] = m_value[i];
205  }
206  return *this;
207  }
208 
213  RAJA_INLINE
214  self_type const& store_strided_n(element_type* ptr,
215  camp::idx_t stride,
216  camp::idx_t N) const
217  {
218  for (camp::idx_t i = 0; i < N; ++i)
219  {
220  ptr[i * stride] = m_value[i];
221  }
222  return *this;
223  }
224 
230  RAJA_INLINE
231  element_type get(camp::idx_t i) const { return m_value[i]; }
232 
238  RAJA_INLINE
239  self_type& set(element_type value, camp::idx_t i)
240  {
241  m_value[i] = value;
242  return *this;
243  }
244 
246 
247  RAJA_INLINE
248  self_type& broadcast(element_type const& value)
249  {
250  m_value = _mm256_set1_ps(value);
251  return *this;
252  }
253 
255 
256  RAJA_INLINE
257  self_type& copy(self_type const& src)
258  {
259  m_value = src.m_value;
260  return *this;
261  }
262 
264 
265  RAJA_INLINE
266  self_type add(self_type const& b) const
267  {
268  return self_type(_mm256_add_ps(m_value, b.m_value));
269  }
270 
272 
273  RAJA_INLINE
274  self_type subtract(self_type const& b) const
275  {
276  return self_type(_mm256_sub_ps(m_value, b.m_value));
277  }
278 
280 
281  RAJA_INLINE
282  self_type multiply(self_type const& b) const
283  {
284  return self_type(_mm256_mul_ps(m_value, b.m_value));
285  }
286 
288 
289  RAJA_INLINE
290  self_type divide(self_type const& b) const
291  {
292  return self_type(_mm256_div_ps(m_value, b.m_value));
293  }
294 
296 
297  RAJA_INLINE
298  self_type divide_n(self_type const& b, camp::idx_t N) const
299  {
300  // AVX2 does not supply a masked divide
301  return self_type(_mm256_set_ps(
302  N >= 8 ? get(7) / b.get(7) : 0, N >= 7 ? get(6) / b.get(6) : 0,
303  N >= 6 ? get(5) / b.get(5) : 0, N >= 5 ? get(4) / b.get(4) : 0,
304  N >= 4 ? get(3) / b.get(3) : 0, N >= 3 ? get(2) / b.get(2) : 0,
305  N >= 2 ? get(1) / b.get(1) : 0, N >= 1 ? get(0) / b.get(0) : 0));
306  }
307 
312  RAJA_INLINE
313  element_type sum() const
314  {
315  // swap odd-even pairs and add
316  auto sh1 = _mm256_permute_ps(m_value, 0xB1);
317  auto red1 = _mm256_add_ps(m_value, sh1);
318 
319  // swap odd-even quads and add
320  auto sh2 = _mm256_permute_ps(red1, 0x4E);
321  auto red2 = _mm256_add_ps(red1, sh2);
322 
323  return red2[0] + red2[4];
324  }
325 
330  RAJA_INLINE
331  element_type max() const
332  {
333  // swap odd-even pairs and combine
334  auto sh1 = _mm256_permute_ps(m_value, 0xB1);
335  auto red1 = _mm256_max_ps(m_value, sh1);
336 
337  // swap odd-even quads and combine
338  auto sh2 = _mm256_permute_ps(red1, 0x4E);
339  auto red2 = _mm256_max_ps(red1, sh2);
340 
341  // combine quads
342  return RAJA::max<element_type>(red2[0], red2[4]);
343  }
344 
349  RAJA_INLINE
350  element_type max_n(camp::idx_t N) const
351  {
352  // Some simple cases
353  if (N <= 0 || N > 8)
354  {
356  }
357  if (N == 1)
358  {
359  return m_value[0];
360  }
361  if (N == 2)
362  {
363  return RAJA::max<element_type>(m_value[0], m_value[1]);
364  }
365 
366  // swap odd-even pairs and add
367  auto sh1 = _mm256_permute_ps(m_value, 0xB1);
368 
369  if (N == 7)
370  {
371  // blend out the 8th lane of the permute
372  sh1 = _mm256_blend_ps(sh1, m_value, 0x40);
373  }
374 
375  auto red1 = _mm256_max_ps(m_value, sh1);
376 
377  // Some more simple shortcuts
378  if (N == 3)
379  {
380  return RAJA::max<element_type>(red1[0], m_value[2]);
381  }
382 
383 
384  // swap odd-even quads and add
385  auto sh2 = _mm256_permute_ps(red1, 0x4E);
386  auto red2 = _mm256_max_ps(red1, sh2);
387 
388  if (N == 4)
389  {
390  return red2[0];
391  }
392  if (N == 5)
393  {
394  return RAJA::max<element_type>(red2[0], m_value[4]);
395  }
396  if (N == 6)
397  {
398  return RAJA::max<element_type>(red2[0], red1[4]);
399  }
400 
401  // 7 or 8 lanes
402  return RAJA::max<element_type>(red2[0], red2[4]);
403  }
404 
409  RAJA_INLINE
410  self_type vmax(self_type a) const
411  {
412  return self_type(_mm256_max_ps(m_value, a.m_value));
413  }
414 
419  RAJA_INLINE
420  element_type min() const
421  {
422  // swap odd-even pairs and combine
423  auto sh1 = _mm256_permute_ps(m_value, 0xB1);
424  auto red1 = _mm256_min_ps(m_value, sh1);
425 
426  // swap odd-even quads and combine
427  auto sh2 = _mm256_permute_ps(red1, 0x4E);
428  auto red2 = _mm256_min_ps(red1, sh2);
429 
430  // combine quads
431  return RAJA::min<element_type>(red2[0], red2[4]);
432  }
433 
438  RAJA_INLINE
439  element_type min_n(camp::idx_t N) const
440  {
441  // Some simple cases
442  if (N <= 0 || N > 8)
443  {
445  }
446  if (N == 1)
447  {
448  return m_value[0];
449  }
450  if (N == 2)
451  {
452  return RAJA::min<element_type>(m_value[0], m_value[1]);
453  }
454 
455  // swap odd-even pairs and add
456  auto sh1 = _mm256_permute_ps(m_value, 0xB1);
457 
458  if (N == 7)
459  {
460  // blend out the 8th lane of the permute
461  sh1 = _mm256_blend_ps(sh1, m_value, 0x40);
462  }
463 
464  auto red1 = _mm256_min_ps(m_value, sh1);
465 
466  // Some more simple shortcuts
467  if (N == 3)
468  {
469  return RAJA::min<element_type>(red1[0], m_value[2]);
470  }
471 
472 
473  // swap odd-even quads and add
474  auto sh2 = _mm256_permute_ps(red1, 0x4E);
475  auto red2 = _mm256_min_ps(red1, sh2);
476 
477  if (N == 4)
478  {
479  return red2[0];
480  }
481  if (N == 5)
482  {
483  return RAJA::min<element_type>(red2[0], m_value[4]);
484  }
485  if (N == 6)
486  {
487  return RAJA::min<element_type>(red2[0], red1[4]);
488  }
489 
490  // 7 or 8 lanes
491  return RAJA::min<element_type>(red2[0], red2[4]);
492  }
493 
498  RAJA_INLINE
499  self_type vmin(self_type a) const
500  {
501  return self_type(_mm256_min_ps(m_value, a.m_value));
502  }
503 };
504 
505 
506 } // namespace expt
507 
508 } // namespace RAJA
509 
510 
511 #endif
512 
513 #endif //__AVX__
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155