RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
avx_int64.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifdef __AVX__
21 
22 #ifndef RAJA_policy_vector_register_avx_int64_HPP
23 #define RAJA_policy_vector_register_avx_int64_HPP
24 
25 #include "RAJA/config.hpp"
26 #include "RAJA/util/macros.hpp"
28 
29 // Include SIMD intrinsics header file
30 #include <immintrin.h>
31 #include <cmath>
32 
33 namespace RAJA
34 {
35 namespace expt
36 {
37 template<>
38 class Register<int64_t, avx_register>
39  : public internal::expt::RegisterBase<Register<int64_t, avx_register>>
40 {
41 public:
42  using base_type =
43  internal::expt::RegisterBase<Register<int64_t, avx_register>>;
44 
45  using register_policy = avx_register;
46  using self_type = Register<int64_t, avx_register>;
47  using element_type = int64_t;
48  using register_type = __m256i;
49 
50  using int_vector_type = Register<int64_t, avx_register>;
51 
52 
53 private:
54  register_type m_value;
55 
56  RAJA_INLINE
57  __m256i createMask(camp::idx_t N) const
58  {
59  // Generate a mask
60  return _mm256_set_epi64x(N >= 4 ? -1 : 0, N >= 3 ? -1 : 0, N >= 2 ? -1 : 0,
61  N >= 1 ? -1 : 0);
62  }
63 
64  RAJA_INLINE
65  __m256i createStridedOffsets(camp::idx_t stride) const
66  {
67  // Generate a strided offset list
68  return _mm256_set_epi64x(3 * stride, 2 * stride, stride, 0);
69  }
70 
71  /*
72  * Use the packed-double permute function because there isn't one
73  * specifically for int64
74  *
75  * Just adds a bunch of casting, should be same cost
76  */
77  template<int perm>
78  RAJA_INLINE __m256i permute(__m256i x) const
79  {
80  return _mm256_castpd_si256(_mm256_permute_pd(_mm256_castsi256_pd(x), perm));
81  }
82 
83 public:
84  static constexpr camp::idx_t s_num_elem = 4;
85 
89  RAJA_INLINE
90  Register() : base_type(), m_value(_mm256_setzero_si256()) {}
91 
95  RAJA_INLINE
96  explicit Register(register_type const& c) : base_type(), m_value(c) {}
97 
101  RAJA_INLINE
102  Register(element_type x0, element_type x1, element_type x2, element_type x3)
103  : m_value(_mm256_set_epi64x(x3, x2, x1, x0))
104  {}
105 
109  RAJA_INLINE
110  Register(self_type const& c) : base_type(), m_value(c.m_value) {}
111 
115  RAJA_INLINE
116  self_type& operator=(self_type const& c)
117  {
118  m_value = c.m_value;
119  return *this;
120  }
121 
126  RAJA_INLINE
127  Register(element_type const& c) : m_value(_mm256_set1_epi64x(c)) {}
128 
133  RAJA_INLINE
134  self_type& load_packed(element_type const* ptr)
135  {
136  m_value = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(ptr));
137  return *this;
138  }
139 
145  RAJA_INLINE
146  self_type& load_packed_n(element_type const* ptr, camp::idx_t N)
147  {
148  m_value = _mm256_castpd_si256(_mm256_maskload_pd(
149  reinterpret_cast<double const*>(ptr), createMask(N)));
150  return *this;
151  }
152 
157  RAJA_INLINE
158  self_type& load_strided(element_type const* ptr, camp::idx_t stride)
159  {
160  for (camp::idx_t i = 0; i < 4; ++i)
161  {
162  m_value[i] = ptr[i * stride];
163  }
164  return *this;
165  }
166 
172  RAJA_INLINE
173  self_type& load_strided_n(element_type const* ptr,
174  camp::idx_t stride,
175  camp::idx_t N)
176  {
177  m_value = _mm256_setzero_si256();
178  for (camp::idx_t i = 0; i < N; ++i)
179  {
180  m_value[i] = ptr[i * stride];
181  }
182  return *this;
183  }
184 
189  RAJA_INLINE
190  self_type const& store_packed(element_type* ptr) const
191  {
192  _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), m_value);
193  return *this;
194  }
195 
200  RAJA_INLINE
201  self_type const& store_packed_n(element_type* ptr, camp::idx_t N) const
202  {
203  _mm256_maskstore_pd(reinterpret_cast<double*>(ptr), createMask(N),
204  reinterpret_cast<__m256d>(m_value));
205  return *this;
206  }
207 
212  RAJA_INLINE
213  self_type const& store_strided(element_type* ptr, camp::idx_t stride) const
214  {
215  for (camp::idx_t i = 0; i < 4; ++i)
216  {
217  ptr[i * stride] = m_value[i];
218  }
219  return *this;
220  }
221 
226  RAJA_INLINE
227  self_type const& store_strided_n(element_type* ptr,
228  camp::idx_t stride,
229  camp::idx_t N) const
230  {
231  for (camp::idx_t i = 0; i < N; ++i)
232  {
233  ptr[i * stride] = m_value[i];
234  }
235  return *this;
236  }
237 
243  RAJA_INLINE
244  element_type get(camp::idx_t i) const
245  {
246  // got to be a nicer way to do this!?!?
247  switch (i)
248  {
249  case 0:
250  return _mm256_extract_epi64(m_value, 0);
251  case 1:
252  return _mm256_extract_epi64(m_value, 1);
253  case 2:
254  return _mm256_extract_epi64(m_value, 2);
255  case 3:
256  return _mm256_extract_epi64(m_value, 3);
257  }
258  return 0;
259  }
260 
266  RAJA_INLINE
267  self_type& set(element_type value, camp::idx_t i)
268  {
269  // got to be a nicer way to do this!?!?
270  switch (i)
271  {
272  case 0:
273  m_value = _mm256_insert_epi64(m_value, value, 0);
274  break;
275  case 1:
276  m_value = _mm256_insert_epi64(m_value, value, 1);
277  break;
278  case 2:
279  m_value = _mm256_insert_epi64(m_value, value, 2);
280  break;
281  case 3:
282  m_value = _mm256_insert_epi64(m_value, value, 3);
283  break;
284  }
285 
286  return *this;
287  }
288 
290 
291  RAJA_INLINE
292  self_type& broadcast(element_type const& value)
293  {
294  m_value = _mm256_set1_epi64x(value);
295  return *this;
296  }
297 
299 
300  RAJA_INLINE
301  self_type& copy(self_type const& src)
302  {
303  m_value = src.m_value;
304  return *this;
305  }
306 
308 
309  RAJA_INLINE
310  self_type add(self_type const& b) const
311  {
312  // no 4-way 64-bit add, but there is a 2-way SSE... split and conquer
313 
314  // Low 128-bits - use _mm256_castsi256_si128???
315  auto low_a = _mm256_castsi256_si128(m_value);
316  auto low_b = _mm256_castsi256_si128(b.m_value);
317  auto res_low = _mm256_castsi128_si256(_mm_add_epi64(low_a, low_b));
318 
319  // Hi 128-bits
320  auto hi_a = _mm256_extractf128_si256(m_value, 1);
321  auto hi_b = _mm256_extractf128_si256(b.m_value, 1);
322  auto res_hi = _mm_add_epi64(hi_a, hi_b);
323 
324  // Stitch back together
325  return self_type(_mm256_insertf128_si256(res_low, res_hi, 1));
326  }
327 
329 
330  RAJA_INLINE
331  self_type subtract(self_type const& b) const
332  {
333  // no 4-way 64-bit subtract, but there is a 2-way SSE... split and conquer
334 
335  // Low 128-bits - use _mm256_castsi256_si128???
336  auto low_a = _mm256_castsi256_si128(m_value);
337  auto low_b = _mm256_castsi256_si128(b.m_value);
338  auto res_low = _mm256_castsi128_si256(_mm_sub_epi64(low_a, low_b));
339 
340  // Hi 128-bits
341  auto hi_a = _mm256_extractf128_si256(m_value, 1);
342  auto hi_b = _mm256_extractf128_si256(b.m_value, 1);
343  auto res_hi = _mm_sub_epi64(hi_a, hi_b);
344 
345  // Stitch back together
346  return self_type(_mm256_insertf128_si256(res_low, res_hi, 1));
347  }
348 
350 
351  RAJA_INLINE
352  self_type multiply(self_type const& b) const
353  {
354  // AVX2 does not supply an int64_t multiply, so do it manually
355  return self_type(_mm256_set_epi64x(get(3) * b.get(3), get(2) * b.get(2),
356  get(1) * b.get(1), get(0) * b.get(0)));
357  }
358 
360 
361  RAJA_INLINE
362  self_type divide(self_type const& b) const
363  {
364  // AVX2 does not supply an integer divide, so do it manually
365  return self_type(_mm256_set_epi64x(get(3) / b.get(3), get(2) / b.get(2),
366  get(1) / b.get(1), get(0) / b.get(0)));
367  }
368 
370 
371  RAJA_INLINE
372  self_type divide_n(self_type const& b, camp::idx_t N) const
373  {
374  // AVX2 does not supply an integer divide, so do it manually
375  return self_type(_mm256_set_epi64x(
376  N >= 4 ? get(3) / b.get(3) : 0, N >= 3 ? get(2) / b.get(2) : 0,
377  N >= 2 ? get(1) / b.get(1) : 0, N >= 1 ? get(0) / b.get(0) : 0));
378  }
379 
384  RAJA_INLINE
385  element_type sum() const
386  {
387  // swap pairs and add
388  auto sh1 = permute<0x5>(m_value);
389 
390  // Add lower 128-bits
391  auto low_a = _mm256_castsi256_si128(m_value);
392  auto low_b = _mm256_castsi256_si128(sh1);
393  auto res_low = _mm_add_epi64(low_a, low_b);
394 
395  // Add upper 128-bits
396  auto hi_a = _mm256_extractf128_si256(m_value, 1);
397  auto hi_b = _mm256_extractf128_si256(sh1, 1);
398  auto res_hi = _mm_add_epi64(hi_a, hi_b);
399 
400  // Sum upper and lower
401  auto res = _mm_add_epi64(res_hi, res_low);
402 
403  // add lower and upper
404  return _mm_extract_epi64(res, 0);
405  }
406 
411  RAJA_INLINE
412  element_type max() const
413  {
414  // AVX2 does not supply an 64bit integer max!
415  auto red = get(0);
416 
417  auto v1 = get(1);
418  red = red < v1 ? v1 : red;
419 
420  auto v2 = get(2);
421  red = red < v2 ? v2 : red;
422 
423  auto v3 = get(3);
424  red = red < v3 ? v3 : red;
425 
426  return red;
427  }
428 
433  RAJA_INLINE
434  element_type max_n(camp::idx_t N) const
435  {
436  if (N <= 0 || N > 4)
437  {
439  }
440 
441  // AVX2 does not supply an 64bit integer max?!?
442  auto red = get(0);
443 
444  if (N > 1)
445  {
446  auto v1 = get(1);
447  red = red < v1 ? v1 : red;
448  }
449  if (N > 2)
450  {
451  auto v2 = get(2);
452  red = red < v2 ? v2 : red;
453  }
454  if (N > 3)
455  {
456  auto v3 = get(3);
457  red = red < v3 ? v3 : red;
458  }
459 
460  return red;
461  }
462 
467  RAJA_INLINE
468  self_type vmax(self_type a) const
469  {
470  return self_type(_mm256_set_epi64x(get(3) > a.get(3) ? get(3) : a.get(3),
471  get(2) > a.get(2) ? get(2) : a.get(2),
472  get(1) > a.get(1) ? get(1) : a.get(1),
473  get(0) > a.get(0) ? get(0) : a.get(0)));
474  }
475 
480  RAJA_INLINE
481  element_type min() const
482  {
483 
484  // AVX2 does not supply an 64bit integer max?!?
485  auto red = get(0);
486 
487  auto v1 = get(1);
488  red = red > v1 ? v1 : red;
489 
490  auto v2 = get(2);
491  red = red > v2 ? v2 : red;
492 
493  auto v3 = get(3);
494  red = red > v3 ? v3 : red;
495 
496  return red;
497  }
498 
503  RAJA_INLINE
504  element_type min_n(camp::idx_t N) const
505  {
506  if (N <= 0 || N > 4)
507  {
509  }
510 
511  // AVX2 does not supply an 64bit integer max?!?
512  auto red = get(0);
513 
514  if (N > 1)
515  {
516  auto v1 = get(1);
517  red = red > v1 ? v1 : red;
518  }
519  if (N > 2)
520  {
521  auto v2 = get(2);
522  red = red > v2 ? v2 : red;
523  }
524  if (N > 3)
525  {
526  auto v3 = get(3);
527  red = red > v3 ? v3 : red;
528  }
529 
530  return red;
531  }
532 
537  RAJA_INLINE
538  self_type vmin(self_type a) const
539  {
540  return self_type(_mm256_set_epi64x(get(3) < a.get(3) ? get(3) : a.get(3),
541  get(2) < a.get(2) ? get(2) : a.get(2),
542  get(1) < a.get(1) ? get(1) : a.get(1),
543  get(0) < a.get(0) ? get(0) : a.get(0)));
544  }
545 };
546 
547 
548 } // namespace expt
549 
550 } // namespace RAJA
551 
552 
553 #endif
554 
555 #endif //__AVX2__
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155