RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
avx512_int64.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifdef __AVX512F__
21 
22 #ifndef RAJA_policy_vector_register_avx512_long_HPP
23 #define RAJA_policy_vector_register_avx512_long_HPP
24 
25 #include "RAJA/config.hpp"
26 #include "RAJA/util/macros.hpp"
28 
29 // Include SIMD intrinsics header file
30 #include <immintrin.h>
31 #include <cmath>
32 
33 namespace RAJA
34 {
35 namespace expt
36 {
37 template<>
38 class Register<int64_t, avx512_register>
39  : public internal::expt::RegisterBase<Register<int64_t, avx512_register>>
40 {
41 public:
42  using base_type =
43  internal::expt::RegisterBase<Register<int64_t, avx512_register>>;
44 
45  using register_policy = avx512_register;
46  using self_type = Register<int64_t, avx512_register>;
47  using element_type = int64_t;
48  using register_type = __m512i;
49 
50  using int_vector_type = Register<int64_t, avx512_register>;
51 
52 
53 private:
54  register_type m_value;
55 
56  RAJA_INLINE
57  __mmask8 createMask(camp::idx_t N) const
58  {
59  // Generate a mask
60  switch (N)
61  {
62  case 0:
63  return __mmask8(0x00);
64  case 1:
65  return __mmask8(0x01);
66  case 2:
67  return __mmask8(0x03);
68  case 3:
69  return __mmask8(0x07);
70  case 4:
71  return __mmask8(0x0F);
72  case 5:
73  return __mmask8(0x1F);
74  case 6:
75  return __mmask8(0x3F);
76  case 7:
77  return __mmask8(0x7F);
78  case 8:
79  return __mmask8(0xFF);
80  }
81  return __mmask8(0);
82  }
83 
84  RAJA_INLINE
85  __m512i createStridedOffsets(camp::idx_t stride) const
86  {
87  // Generate a strided offset list
88  auto vstride = _mm512_set1_epi64(stride);
89  auto vseq = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0);
90  return _mm512_mullo_epi64(vstride, vseq);
91  }
92 
93 public:
94  static constexpr camp::idx_t s_num_elem = 8;
95 
99  // AVX512F
100  RAJA_INLINE
101  Register() : base_type(), m_value(_mm512_setzero_epi32()) {}
102 
106  RAJA_INLINE
107  explicit Register(register_type const& c) : base_type(), m_value(c) {}
108 
112  RAJA_INLINE
113  Register(self_type const& c) : base_type(), m_value(c.m_value) {}
114 
118  RAJA_INLINE
119  self_type& operator=(self_type const& c)
120  {
121  m_value = c.m_value;
122  return *this;
123  }
124 
129  // AVX512F
130  RAJA_INLINE
131  Register(element_type const& c) : base_type(), m_value(_mm512_set1_epi64(c))
132  {}
133 
138  RAJA_INLINE
139  self_type& load_packed(element_type const* ptr)
140  {
141  // AVX512F
142 #if (defined(__GNUC__) && ((__GNUC__ >= 7) && (__GNUC__ <= 10))) || \
143  (!defined(SYCL_LANGUAGE_VERSION) && \
144  defined(__INTEL_LLVM_COMPILER)) // Check for oneapi's icpx.
145  m_value = _mm512_maskz_loadu_epi64(
146  ~0,
147  ptr); // May cause slowdown due to looping over 8 bytes, one at a time.
148 #else
149  m_value =
150  _mm512_loadu_epi64(ptr); // GNU 7-10 are missing this instruction, as
151  // is icpx as of version 2022.2.
152 #endif
153  return *this;
154  }
155 
161  RAJA_INLINE
162  self_type& load_packed_n(element_type const* ptr, camp::idx_t N)
163  {
164  // AVX512F
165  m_value =
166  _mm512_mask_loadu_epi64(_mm512_setzero_epi32(), createMask(N), ptr);
167  return *this;
168  }
169 
174  RAJA_INLINE
175  self_type& load_strided(element_type const* ptr, camp::idx_t stride)
176  {
177  // AVX512F
178  m_value = _mm512_i64gather_epi64(createStridedOffsets(stride), ptr,
179  sizeof(element_type));
180  return *this;
181  }
182 
188  RAJA_INLINE
189  self_type& load_strided_n(element_type const* ptr,
190  camp::idx_t stride,
191  camp::idx_t N)
192  {
193  // AVX512F
194  m_value = _mm512_mask_i64gather_epi64(_mm512_setzero_epi32(), createMask(N),
195  createStridedOffsets(stride), ptr,
196  sizeof(element_type));
197  return *this;
198  }
199 
204  RAJA_INLINE
205  self_type const& store_packed(element_type* ptr) const
206  {
207  // AVX512F
208 #if (defined(__GNUC__) && ((__GNUC__ >= 7) && (__GNUC__ <= 10))) || \
209  (!defined(SYCL_LANGUAGE_VERSION) && \
210  defined(__INTEL_LLVM_COMPILER)) // Check for oneapi's icpx.
211  _mm512_mask_storeu_epi64(ptr, ~0,
212  m_value); // May cause slowdown due to looping
213  // over 8 bytes, one at a time.
214 #else
215  _mm512_storeu_epi64(ptr,
216  m_value); // GNU 7-10 are missing this instruction, as
217  // is icpx as of version 2022.2.
218 #endif
219  return *this;
220  }
221 
226  RAJA_INLINE
227  self_type const& store_packed_n(element_type* ptr, camp::idx_t N) const
228  {
229  // AVX512F
230  _mm512_mask_storeu_epi64(ptr, createMask(N), m_value);
231  return *this;
232  }
233 
238  RAJA_INLINE
239  self_type const& store_strided(element_type* ptr, camp::idx_t stride) const
240  {
241  // AVX512F
242  _mm512_i64scatter_epi64(ptr, createStridedOffsets(stride), m_value,
243  sizeof(element_type));
244  return *this;
245  }
246 
251  RAJA_INLINE
252  self_type const& store_strided_n(element_type* ptr,
253  camp::idx_t stride,
254  camp::idx_t N) const
255  {
256  // AVX512F
257  _mm512_mask_i64scatter_epi64(ptr, createMask(N),
258  createStridedOffsets(stride), m_value,
259  sizeof(element_type));
260  return *this;
261  }
262 
268  RAJA_INLINE
269  element_type get(camp::idx_t i) const { return m_value[i]; }
270 
276  RAJA_INLINE
277  self_type& set(element_type value, camp::idx_t i)
278  {
279  m_value[i] = value;
280  return *this;
281  }
282 
284 
285  RAJA_INLINE
286  self_type& broadcast(element_type const& value)
287  {
288  m_value = _mm512_set1_epi64(value);
289  return *this;
290  }
291 
293 
294  RAJA_INLINE
295  self_type& copy(self_type const& src)
296  {
297  m_value = src.m_value;
298  return *this;
299  }
300 
302 
303  RAJA_INLINE
304  self_type add(self_type const& b) const
305  {
306  return self_type(_mm512_add_epi64(m_value, b.m_value));
307  }
308 
310 
311  RAJA_INLINE
312  self_type subtract(self_type const& b) const
313  {
314  return self_type(_mm512_sub_epi64(m_value, b.m_value));
315  }
316 
318 
319  RAJA_INLINE
320  self_type multiply(self_type const& b) const
321  {
322  return self_type(_mm512_mullo_epi64(m_value, b.m_value));
323  }
324 
326 
327  RAJA_INLINE
328  self_type divide(self_type const& b) const
329  {
330  // AVX512 does not supply an integer divide, so do it manually
331  return self_type(_mm512_set_epi64(get(7) / b.get(7), get(6) / b.get(6),
332  get(5) / b.get(5), get(4) / b.get(4),
333  get(3) / b.get(3), get(2) / b.get(2),
334  get(1) / b.get(1), get(0) / b.get(0)));
335  }
336 
338 
339  RAJA_INLINE
340  self_type divide_n(self_type const& b, camp::idx_t N) const
341  {
342  // AVX512 does not supply an integer divide, so do it manually
343  return self_type(_mm512_set_epi64(
344  N >= 8 ? get(7) / b.get(7) : 0, N >= 7 ? get(6) / b.get(6) : 0,
345  N >= 6 ? get(5) / b.get(5) : 0, N >= 5 ? get(4) / b.get(4) : 0,
346  N >= 4 ? get(3) / b.get(3) : 0, N >= 3 ? get(2) / b.get(2) : 0,
347  N >= 2 ? get(1) / b.get(1) : 0, N >= 1 ? get(0) / b.get(0) : 0));
348  }
349 
354  RAJA_INLINE
355  element_type sum() const { return _mm512_reduce_add_epi64(m_value); }
356 
361  RAJA_INLINE
362  element_type max() const { return _mm512_reduce_max_epi64(m_value); }
363 
368  RAJA_INLINE
369  element_type max_n(camp::idx_t N) const
370  {
371  return _mm512_mask_reduce_max_epi64(createMask(N), m_value);
372  }
373 
378  RAJA_INLINE
379  self_type vmax(self_type a) const
380  {
381  return self_type(_mm512_max_epi64(m_value, a.m_value));
382  }
383 
388  RAJA_INLINE
389  element_type min() const { return _mm512_reduce_min_epi64(m_value); }
390 
395  RAJA_INLINE
396  element_type min_n(camp::idx_t N) const
397  {
398  return _mm512_mask_reduce_min_epi64(createMask(N), m_value);
399  }
400 
405  RAJA_INLINE
406  self_type vmin(self_type a) const
407  {
408  return self_type(_mm512_min_epi64(m_value, a.m_value));
409  }
410 };
411 
412 
413 } // namespace expt
414 
415 } // namespace RAJA
416 
417 
418 #endif
419 
420 #endif //__AVX512F__
RAJA header file defining SIMD/SIMT register operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155