RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
avx2_int32.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifdef __AVX2__
21 
22 #ifndef RAJA_policy_vector_register_avx2_int32_HPP
23 #define RAJA_policy_vector_register_avx2_int32_HPP
24 
25 #include "RAJA/config.hpp"
26 #include "RAJA/util/macros.hpp"
28 
29 // Include SIMD intrinsics header file
30 #include <immintrin.h>
31 #include <cmath>
32 
33 namespace RAJA
34 {
35 namespace expt
36 {
37 
38 template<>
39 class Register<int32_t, avx2_register>
40  : public internal::expt::RegisterBase<Register<int32_t, avx2_register>>
41 {
42 public:
43  using base_type =
44  internal::expt::RegisterBase<Register<int32_t, avx2_register>>;
45 
46  using register_policy = avx2_register;
47  using self_type = Register<int32_t, avx2_register>;
48  using element_type = int32_t;
49  using register_type = __m256i;
50 
51  using int_vector_type = Register<int32_t, avx2_register>;
52 
53 
54 private:
55  register_type m_value;
56 
57  RAJA_INLINE
58  __m256i createMask(camp::idx_t N) const
59  {
60  // Generate a mask
61  return _mm256_set_epi32(N >= 8 ? -1 : 0, N >= 7 ? -1 : 0, N >= 6 ? -1 : 0,
62  N >= 5 ? -1 : 0, N >= 4 ? -1 : 0, N >= 3 ? -1 : 0,
63  N >= 2 ? -1 : 0, N >= 1 ? -1 : 0);
64  }
65 
66  RAJA_INLINE
67  __m256i createStridedOffsets(camp::idx_t stride) const
68  {
69  // Generate a strided offset list
70  return _mm256_set_epi32(7 * stride, 6 * stride, 5 * stride, 4 * stride,
71  3 * stride, 2 * stride, stride, 0);
72  }
73 
74  RAJA_INLINE
75  __m256i createPermute1(camp::idx_t N) const
76  {
77  // Generate a permutation for first round of min/max routines
78  return _mm256_set_epi32(N >= 7 ? 6 : 0, N >= 8 ? 7 : 0, N >= 5 ? 4 : 0,
79  N >= 6 ? 5 : 0, N >= 3 ? 2 : 0, N >= 4 ? 3 : 0,
80  N >= 1 ? 0 : 0, N >= 2 ? 1 : 0);
81  }
82 
83  RAJA_INLINE
84  __m256i createPermute2(camp::idx_t N) const
85  {
86  // Generate a permutation for second round of min/max routines
87  return _mm256_set_epi32(N >= 6 ? 5 : 0, N >= 5 ? 4 : 0, N >= 8 ? 7 : 0,
88  N >= 7 ? 6 : 0, N >= 2 ? 1 : 0, N >= 1 ? 0 : 0,
89  N >= 4 ? 3 : 0, N >= 2 ? 2 : 0);
90  }
91 
92 public:
93  static constexpr camp::idx_t s_num_elem = 8;
94 
98  RAJA_INLINE
99  Register() : m_value(_mm256_setzero_si256()) {}
100 
104  RAJA_INLINE
105  Register(element_type x0,
106  element_type x1,
107  element_type x2,
108  element_type x3,
109  element_type x4,
110  element_type x5,
111  element_type x6,
112  element_type x7)
113  : m_value(_mm256_set_epi32(x7, x6, x5, x4, x3, x2, x1, x0))
114  {}
115 
119  RAJA_INLINE
120  explicit Register(register_type const& c) : m_value(c) {}
121 
125  RAJA_INLINE
126  Register(self_type const& c) : base_type(c), m_value(c.m_value) {}
127 
131  RAJA_INLINE
132  self_type& operator=(self_type const& c)
133  {
134  m_value = c.m_value;
135  return *this;
136  }
137 
142  RAJA_INLINE
143  Register(element_type const& c) : m_value(_mm256_set1_epi32(c)) {}
144 
148  RAJA_INLINE
149  constexpr register_type get_register() const { return m_value; }
150 
155  RAJA_INLINE
156  self_type& load_packed(element_type const* ptr)
157  {
158  m_value = _mm256_loadu_si256((__m256i const*)ptr);
159  return *this;
160  }
161 
167  RAJA_INLINE
168  self_type& load_packed_n(element_type const* ptr, camp::idx_t N)
169  {
170  m_value = _mm256_maskload_epi32(ptr, createMask(N));
171  return *this;
172  }
173 
178  RAJA_INLINE
179  self_type& load_strided(element_type const* ptr, camp::idx_t stride)
180  {
181  m_value = _mm256_i32gather_epi32(ptr, createStridedOffsets(stride),
182  sizeof(element_type));
183  return *this;
184  }
185 
191  RAJA_INLINE
192  self_type& load_strided_n(element_type const* ptr,
193  camp::idx_t stride,
194  camp::idx_t N)
195  {
196  m_value = _mm256_mask_i32gather_epi32(_mm256_setzero_si256(), ptr,
197  createStridedOffsets(stride),
198  createMask(N), sizeof(element_type));
199  return *this;
200  }
201 
206  RAJA_INLINE
207  self_type const& store_packed(element_type* ptr) const
208  {
209  _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), m_value);
210  return *this;
211  }
212 
217  RAJA_INLINE
218  self_type const& store_packed_n(element_type* ptr, camp::idx_t N) const
219  {
220  _mm256_maskstore_epi32(ptr, createMask(N), m_value);
221  return *this;
222  }
223 
228  RAJA_INLINE
229  self_type const& store_strided(element_type* ptr, camp::idx_t stride) const
230  {
231  for (camp::idx_t i = 0; i < 8; ++i)
232  {
233  ptr[i * stride] = get(i);
234  }
235  return *this;
236  }
237 
242  RAJA_INLINE
243  self_type const& store_strided_n(element_type* ptr,
244  camp::idx_t stride,
245  camp::idx_t N) const
246  {
247  for (camp::idx_t i = 0; i < N; ++i)
248  {
249  ptr[i * stride] = get(i);
250  }
251  return *this;
252  }
253 
259  RAJA_INLINE
260  element_type get(camp::idx_t i) const
261  {
262  // got to be a nicer way to do this!?!?
263  switch (i)
264  {
265  case 0:
266  return _mm256_extract_epi32(m_value, 0);
267  case 1:
268  return _mm256_extract_epi32(m_value, 1);
269  case 2:
270  return _mm256_extract_epi32(m_value, 2);
271  case 3:
272  return _mm256_extract_epi32(m_value, 3);
273  case 4:
274  return _mm256_extract_epi32(m_value, 4);
275  case 5:
276  return _mm256_extract_epi32(m_value, 5);
277  case 6:
278  return _mm256_extract_epi32(m_value, 6);
279  case 7:
280  return _mm256_extract_epi32(m_value, 7);
281  }
282  return 0;
283  }
284 
290  RAJA_INLINE
291  self_type& set(element_type value, camp::idx_t i)
292  {
293  // got to be a nicer way to do this!?!?
294  switch (i)
295  {
296  case 0:
297  m_value = _mm256_insert_epi32(m_value, value, 0);
298  break;
299  case 1:
300  m_value = _mm256_insert_epi32(m_value, value, 1);
301  break;
302  case 2:
303  m_value = _mm256_insert_epi32(m_value, value, 2);
304  break;
305  case 3:
306  m_value = _mm256_insert_epi32(m_value, value, 3);
307  break;
308  case 4:
309  m_value = _mm256_insert_epi32(m_value, value, 4);
310  break;
311  case 5:
312  m_value = _mm256_insert_epi32(m_value, value, 5);
313  break;
314  case 6:
315  m_value = _mm256_insert_epi32(m_value, value, 6);
316  break;
317  case 7:
318  m_value = _mm256_insert_epi32(m_value, value, 7);
319  break;
320  }
321 
322  return *this;
323  }
324 
326 
327  RAJA_INLINE
328  self_type& broadcast(element_type const& value)
329  {
330  m_value = _mm256_set1_epi32(value);
331  return *this;
332  }
333 
335 
336  RAJA_INLINE
337  self_type& copy(self_type const& src)
338  {
339  m_value = src.m_value;
340  return *this;
341  }
342 
344 
345  RAJA_INLINE
346  self_type add(self_type const& b) const
347  {
348  return self_type(_mm256_add_epi32(m_value, b.m_value));
349  }
350 
352 
353  RAJA_INLINE
354  self_type subtract(self_type const& b) const
355  {
356  return self_type(_mm256_sub_epi32(m_value, b.m_value));
357  }
358 
360 
361  RAJA_INLINE
362  self_type multiply(self_type const& b) const
363  {
364 
365  // the AVX2 epi32 multiply only multiplies the even elements
366  // and provides 64-bit results
367  // need to do some repacking to get this to work
368 
369  // multiply 0, 2, 4, 6
370  auto prod_even = _mm256_mul_epi32(m_value, b.m_value);
371 
372  // Swap 32-bit words
373  auto sh_a = _mm256_castps_si256(
374  _mm256_permute_ps(_mm256_castsi256_ps(m_value), 0xB1));
375 
376  auto sh_b = _mm256_castps_si256(
377  _mm256_permute_ps(_mm256_castsi256_ps(b.m_value), 0xB1));
378 
379  // multiply 1, 3, 5, 7
380  auto prod_odd = _mm256_mul_epi32(sh_a, sh_b);
381 
382  // Stitch prod_odd and prod_even back together
383  auto sh_odd = _mm256_castps_si256(
384  _mm256_permute_ps(_mm256_castsi256_ps(prod_odd), 0xB1));
385 
386  return self_type(_mm256_blend_epi32(prod_even, sh_odd, 0xAA));
387  }
388 
390 
391  RAJA_INLINE
392  self_type divide(self_type const& b) const
393  {
394  // AVX2 does not supply an integer divide, so do it manually
395  return self_type(_mm256_set_epi32(get(7) / b.get(7), get(6) / b.get(6),
396  get(5) / b.get(5), get(4) / b.get(4),
397  get(3) / b.get(3), get(2) / b.get(2),
398  get(1) / b.get(1), get(0) / b.get(0)));
399  }
400 
402 
403  RAJA_INLINE
404  self_type divide_n(self_type const& b, camp::idx_t N) const
405  {
406  // AVX2 does not supply an integer divide, so do it manually
407  return self_type(_mm256_set_epi32(
408  N >= 8 ? get(7) / b.get(7) : 0, N >= 7 ? get(6) / b.get(6) : 0,
409  N >= 6 ? get(5) / b.get(5) : 0, N >= 5 ? get(4) / b.get(4) : 0,
410  N >= 4 ? get(3) / b.get(3) : 0, N >= 3 ? get(2) / b.get(2) : 0,
411  N >= 2 ? get(1) / b.get(1) : 0, N >= 1 ? get(0) / b.get(0) : 0));
412  }
413 
418  RAJA_INLINE
419  element_type sum() const
420  {
421  // swap odd-even pairs and add
422  auto sh1 = _mm256_castps_si256(
423  _mm256_permute_ps(_mm256_castsi256_ps(m_value), 0xB1));
424  auto red1 = _mm256_add_epi32(m_value, sh1);
425 
426 
427  // swap odd-even quads and add
428  auto sh2 =
429  _mm256_castps_si256(_mm256_permute_ps(_mm256_castsi256_ps(red1), 0x4E));
430  auto red2 = _mm256_add_epi32(red1, sh2);
431 
432  return _mm256_extract_epi32(red2, 0) + _mm256_extract_epi32(red2, 4);
433  }
434 
439  RAJA_INLINE
440  element_type max() const
441  {
442 
443  // swap odd-even pairs and add
444  auto sh1 = _mm256_permutevar8x32_epi32(m_value, createPermute1(8));
445  auto red1 = _mm256_max_epi32(m_value, sh1);
446 
447  // swap odd-even quads and add
448  auto sh2 = _mm256_permutevar8x32_epi32(red1, createPermute2(8));
449  auto red2 = _mm256_max_epi32(red1, sh2);
450 
451  return std::max<element_type>(_mm256_extract_epi32(red2, 0),
452  _mm256_extract_epi32(red2, 4));
453  }
454 
459  RAJA_INLINE
460  element_type max_n(camp::idx_t N) const
461  {
462  // Some simple cases
463  if (N <= 0 || N > 8)
464  {
466  }
467  if (N == 1)
468  {
469  return get(0);
470  }
471 
472  if (N == 2)
473  {
474  return std::max<element_type>(get(0), get(1));
475  }
476 
477  // swap odd-even pairs and add
478  auto sh1 = _mm256_permutevar8x32_epi32(m_value, createPermute1(N));
479  auto red1 = _mm256_max_epi32(m_value, sh1);
480 
481  if (N == 3)
482  {
483  return std::max<element_type>(_mm256_extract_epi32(red1, 0), get(2));
484  }
485  if (N == 4)
486  {
487  return std::max<element_type>(_mm256_extract_epi32(red1, 0),
488  _mm256_extract_epi32(red1, 2));
489  }
490 
491  // swap odd-even quads and add
492  auto sh2 = _mm256_permutevar8x32_epi32(red1, createPermute2(N));
493  auto red2 = _mm256_max_epi32(red1, sh2);
494 
495  return std::max<element_type>(_mm256_extract_epi32(red2, 0),
496  _mm256_extract_epi32(red2, 4));
497  }
498 
503  RAJA_INLINE
504  self_type vmax(self_type a) const
505  {
506  return self_type(_mm256_max_epi32(m_value, a.m_value));
507  }
508 
513  RAJA_INLINE
514  element_type min() const
515  {
516 
517  // swap odd-even pairs and add
518  auto sh1 = _mm256_permutevar8x32_epi32(m_value, createPermute1(8));
519  auto red1 = _mm256_min_epi32(m_value, sh1);
520 
521 
522  // swap odd-even quads and add
523  auto sh2 = _mm256_permutevar8x32_epi32(red1, createPermute2(8));
524  auto red2 = _mm256_min_epi32(red1, sh2);
525 
526  return std::min<element_type>(_mm256_extract_epi32(red2, 0),
527  _mm256_extract_epi32(red2, 4));
528  }
529 
534  RAJA_INLINE
535  element_type min_n(camp::idx_t N) const
536  {
537  // Some simple cases
538  if (N <= 0 || N > 8)
539  {
541  }
542  if (N == 1)
543  {
544  return get(0);
545  }
546 
547  if (N == 2)
548  {
549  return std::min<element_type>(get(0), get(1));
550  }
551 
552  // swap odd-even pairs and add
553  auto sh1 = _mm256_permutevar8x32_epi32(m_value, createPermute1(N));
554  auto red1 = _mm256_min_epi32(m_value, sh1);
555 
556  if (N == 3)
557  {
558  return std::min<element_type>(_mm256_extract_epi32(red1, 0), get(2));
559  }
560  if (N == 4)
561  {
562  return std::min<element_type>(_mm256_extract_epi32(red1, 0),
563  _mm256_extract_epi32(red1, 2));
564  }
565 
566  // swap odd-even quads and add
567  auto sh2 = _mm256_permutevar8x32_epi32(red1, createPermute2(N));
568  auto red2 = _mm256_min_epi32(red1, sh2);
569 
570  return std::min<element_type>(_mm256_extract_epi32(red2, 0),
571  _mm256_extract_epi32(red2, 4));
572  }
573 
578  RAJA_INLINE
579  self_type vmin(self_type a) const
580  {
581  return self_type(_mm256_min_epi32(m_value, a.m_value));
582  }
583 };
584 
585 
586 } // namespace expt
587 
588 } // namespace RAJA
589 
590 
591 #endif
592 
593 #endif //__AVX2__
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155