RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
cuda_warp.hpp
Go to the documentation of this file.
1 
12 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
13 // Copyright (c) Lawrence Livermore National Security, LLC and other
14 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
15 // files for dates and other details. No copyright assignment is required
16 // to contribute to RAJA.
17 //
18 // SPDX-License-Identifier: (BSD-3-Clause)
19 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
20 
21 #ifndef RAJA_policy_tensor_arch_cuda_cuda_warp_register_HPP
22 #define RAJA_policy_tensor_arch_cuda_cuda_warp_register_HPP
23 
24 #include "RAJA/config.hpp"
25 
26 #if defined(RAJA_CUDA_ACTIVE)
27 
28 #include "RAJA/util/macros.hpp"
30 #include "RAJA/util/macros.hpp"
31 #include "RAJA/util/Operators.hpp"
32 
34 
35 namespace RAJA
36 {
37 namespace expt
38 {
39 
40 template<typename ELEMENT_TYPE>
41 class Register<ELEMENT_TYPE, cuda_warp_register>
42  : public internal::expt::RegisterBase<
43  Register<ELEMENT_TYPE, cuda_warp_register>>
44 {
45 public:
46  using base_type =
47  internal::expt::RegisterBase<Register<ELEMENT_TYPE, cuda_warp_register>>;
48 
49  using register_policy = cuda_warp_register;
50  using self_type = Register<ELEMENT_TYPE, cuda_warp_register>;
51  using element_type = ELEMENT_TYPE;
52  using register_type = ELEMENT_TYPE;
53 
54  using int_vector_type = Register<int64_t, cuda_warp_register>;
55 
56 
57 private:
58  element_type m_value;
59 
60 public:
61  static constexpr int s_num_elem = RAJA_CUDA_WARPSIZE;
62 
66  RAJA_INLINE
67 
69  constexpr Register() : base_type(), m_value(0) {}
70 
74  RAJA_INLINE
75 
77  constexpr Register(element_type c) : base_type(), m_value(c) {}
78 
82  RAJA_INLINE
83 
85  constexpr Register(self_type const& c) : base_type(), m_value(c.m_value) {}
86 
90  RAJA_INLINE
91 
93  self_type& operator=(self_type const& c)
94  {
95  m_value = c.m_value;
96  return *this;
97  }
98 
99  RAJA_INLINE
100 
102  self_type& operator=(element_type c)
103  {
104  m_value = c;
105  return *this;
106  }
107 
111  RAJA_INLINE
112 
114  constexpr static int get_lane() { return threadIdx.x; }
115 
117 
118  RAJA_INLINE
119  constexpr element_type const& get_raw_value() const { return m_value; }
120 
122 
123  RAJA_INLINE
124  element_type& get_raw_value() { return m_value; }
125 
127 
128  RAJA_INLINE
129  static constexpr bool is_root() { return get_lane() == 0; }
130 
135  RAJA_INLINE
136 
138  self_type& load_packed(element_type const* ptr)
139  {
140 
141  auto lane = get_lane();
142 
143  m_value = ptr[lane];
144 
145  return *this;
146  }
147 
153  RAJA_INLINE
154 
156  self_type& load_packed_n(element_type const* ptr, int N)
157  {
158  auto lane = get_lane();
159  if (lane < N)
160  {
161  m_value = ptr[lane];
162  }
163  else
164  {
165  m_value = element_type(0);
166  }
167  return *this;
168  }
169 
174  RAJA_INLINE
175 
177  self_type& load_strided(element_type const* ptr, int stride)
178  {
179 
180  auto lane = get_lane();
181 
182  m_value = ptr[stride * lane];
183 
184  return *this;
185  }
186 
192  RAJA_INLINE
193 
195  self_type& load_strided_n(element_type const* ptr, int stride, int N)
196  {
197  auto lane = get_lane();
198 
199  if (lane < N)
200  {
201  m_value = ptr[stride * lane];
202  }
203  else
204  {
205  m_value = element_type(0);
206  }
207  return *this;
208  }
209 
219  RAJA_INLINE
220 
222  self_type& gather(element_type const* ptr, int_vector_type offsets)
223  {
224 
225  m_value = ptr[offsets.get_raw_value()];
226 
227  return *this;
228  }
229 
239  RAJA_INLINE
240 
242  self_type& gather_n(element_type const* ptr,
243  int_vector_type offsets,
244  camp::idx_t N)
245  {
246  if (get_lane() < N)
247  {
248  m_value = ptr[offsets.get_raw_value()];
249  }
250  else
251  {
252  m_value = element_type(0);
253  }
254 
255  return *this;
256  }
257 
268 
269  RAJA_INLINE
270  self_type& segmented_load(element_type const* ptr,
271  camp::idx_t segbits,
272  camp::idx_t stride_inner,
273  camp::idx_t stride_outer)
274  {
275  auto lane = get_lane();
276 
277  // compute segment and segment_size
278  auto seg = lane >> segbits;
279  auto i = lane & ((1 << segbits) - 1);
280 
281  m_value = ptr[seg * stride_outer + i * stride_inner];
282 
283  return *this;
284  }
285 
294 
295  RAJA_INLINE
296  self_type& segmented_load_nm(element_type const* ptr,
297  camp::idx_t segbits,
298  camp::idx_t stride_inner,
299  camp::idx_t stride_outer,
300  camp::idx_t num_inner,
301  camp::idx_t num_outer)
302  {
303  auto lane = get_lane();
304 
305  // compute segment and segment_size
306  auto seg = lane >> segbits;
307  auto i = lane & ((1 << segbits) - 1);
308 
309  if (seg >= num_outer || i >= num_inner)
310  {
311  m_value = element_type(0);
312  }
313  else
314  {
315  m_value = ptr[seg * stride_outer + i * stride_inner];
316  }
317 
318  return *this;
319  }
320 
325  RAJA_INLINE
326 
328  self_type const& store_packed(element_type* ptr) const
329  {
330 
331  auto lane = get_lane();
332 
333  ptr[lane] = m_value;
334 
335  return *this;
336  }
337 
342  RAJA_INLINE
343 
345  self_type const& store_packed_n(element_type* ptr, int N) const
346  {
347 
348  auto lane = get_lane();
349 
350  if (lane < N)
351  {
352  ptr[lane] = m_value;
353  }
354  return *this;
355  }
356 
361  RAJA_INLINE
362 
364  self_type const& store_strided(element_type* ptr, int stride) const
365  {
366 
367  auto lane = get_lane();
368 
369  ptr[lane * stride] = m_value;
370 
371  return *this;
372  }
373 
378  RAJA_INLINE
379 
381  self_type const& store_strided_n(element_type* ptr, int stride, int N) const
382  {
383 
384  auto lane = get_lane();
385 
386  if (lane < N)
387  {
388  ptr[lane * stride] = m_value;
389  }
390  return *this;
391  }
392 
402  template<typename T2>
403  RAJA_DEVICE RAJA_INLINE self_type const& scatter(element_type* ptr,
404  T2 const& offsets) const
405  {
406 
407  ptr[offsets.get_raw_value()] = m_value;
408 
409  return *this;
410  }
411 
421  template<typename T2>
422  RAJA_DEVICE RAJA_INLINE self_type const& scatter_n(element_type* ptr,
423  T2 const& offsets,
424  camp::idx_t N) const
425  {
426  if (get_lane() < N)
427  {
428  ptr[offsets.get_raw_value()] = m_value;
429  }
430 
431  return *this;
432  }
433 
440 
441  RAJA_INLINE
442  self_type const& segmented_store(element_type* ptr,
443  camp::idx_t segbits,
444  camp::idx_t stride_inner,
445  camp::idx_t stride_outer) const
446  {
447  auto lane = get_lane();
448 
449  // compute segment and segment_size
450  auto seg = lane >> segbits;
451  auto i = lane & ((1 << segbits) - 1);
452 
453  ptr[seg * stride_outer + i * stride_inner] = m_value;
454 
455  return *this;
456  }
457 
464 
465  RAJA_INLINE
466  self_type const& segmented_store_nm(element_type* ptr,
467  camp::idx_t segbits,
468  camp::idx_t stride_inner,
469  camp::idx_t stride_outer,
470  camp::idx_t num_inner,
471  camp::idx_t num_outer) const
472  {
473  auto lane = get_lane();
474 
475  // compute segment and segment_size
476  auto seg = lane >> segbits;
477  auto i = lane & ((1 << segbits) - 1);
478 
479  if (seg >= num_outer || i >= num_inner)
480  {
481  // nop
482  }
483  else
484  {
485  ptr[seg * stride_outer + i * stride_inner] = m_value;
486  }
487 
488  return *this;
489  }
490 
496  constexpr RAJA_INLINE RAJA_DEVICE element_type get(int i) const
497  {
498  return __shfl_sync(0xffffffff, m_value, i);
499  }
500 
506  RAJA_INLINE
507 
509  self_type& set(element_type value, int i)
510  {
511  auto lane = get_lane();
512  if (lane == i)
513  {
514  m_value = value;
515  }
516  return *this;
517  }
518 
520 
521  RAJA_INLINE
522  self_type& broadcast(element_type const& a)
523  {
524  m_value = a;
525  return *this;
526  }
527 
532 
533  RAJA_INLINE
534  self_type get_and_broadcast(int i) const
535  {
536  self_type x;
537  x.m_value = __shfl_sync(0xffffffff, m_value, i);
538  return x;
539  }
540 
542 
543  RAJA_INLINE
544  self_type& copy(self_type const& src)
545  {
546  m_value = src.m_value;
547  return *this;
548  }
549 
551 
552  RAJA_INLINE
553  self_type add(self_type const& b) const
554  {
555  return self_type(m_value + b.m_value);
556  }
557 
559 
560  RAJA_INLINE
561  self_type subtract(self_type const& b) const
562  {
563  return self_type(m_value - b.m_value);
564  }
565 
567 
568  RAJA_INLINE
569  self_type multiply(self_type const& b) const
570  {
571  return self_type(m_value * b.m_value);
572  }
573 
575 
576  RAJA_INLINE
577  self_type divide(self_type const& b) const
578  {
579  return self_type(m_value / b.m_value);
580  }
581 
583 
584  RAJA_INLINE
585  self_type divide_n(self_type const& b, int N) const
586  {
587  return get_lane() < N ? self_type(m_value / b.m_value)
588  : self_type(element_type(0));
589  }
590 
594  template<typename RETURN_TYPE = self_type>
595  RAJA_DEVICE RAJA_INLINE
596  typename std::enable_if<!std::numeric_limits<element_type>::is_integer,
597  RETURN_TYPE>::type
598  multiply_add(self_type const& b, self_type const& c) const
599  {
600  return self_type(fma(m_value, b.m_value, c.m_value));
601  }
602 
606  template<typename RETURN_TYPE = self_type>
607  RAJA_DEVICE RAJA_INLINE
608  typename std::enable_if<std::numeric_limits<element_type>::is_integer,
609  RETURN_TYPE>::type
610  multiply_add(self_type const& b, self_type const& c) const
611  {
612  return self_type(m_value * b.m_value + c.m_value);
613  }
614 
618  template<typename RETURN_TYPE = self_type>
619  RAJA_DEVICE RAJA_INLINE
620  typename std::enable_if<!std::numeric_limits<element_type>::is_integer,
621  RETURN_TYPE>::type
622  multiply_subtract(self_type const& b, self_type const& c) const
623  {
624  return self_type(fma(m_value, b.m_value, -c.m_value));
625  }
626 
630  template<typename RETURN_TYPE = self_type>
631  RAJA_DEVICE RAJA_INLINE
632  typename std::enable_if<std::numeric_limits<element_type>::is_integer,
633  RETURN_TYPE>::type
634  multiply_subtract(self_type const& b, self_type const& c) const
635  {
636  return self_type(m_value * b.m_value - c.m_value);
637  }
638 
643  RAJA_INLINE
644 
646  element_type sum() const
647  {
648  // Allreduce sum
649  using combiner_t =
651 
652  return RAJA::cuda::impl::warp_allreduce<combiner_t, element_type>(m_value);
653  }
654 
659  RAJA_INLINE
660 
662  element_type max() const
663  {
664  // Allreduce maximum
665  using combiner_t =
668 
669  return RAJA::cuda::impl::warp_allreduce<combiner_t, element_type>(m_value);
670  }
671 
676  RAJA_INLINE
677 
679  element_type max_n(int N) const
680  {
681  // Allreduce maximum
682  using combiner_t =
685 
687  auto lane = get_lane();
688  auto value = lane < N ? m_value : ident;
689  return RAJA::cuda::impl::warp_allreduce<combiner_t, element_type>(value);
690  }
691 
696  RAJA_INLINE
697 
699  self_type vmax(self_type a) const
700  {
701  return self_type {RAJA::max<element_type>(m_value, a.m_value)};
702  }
703 
708  RAJA_INLINE
709 
711  element_type min() const
712  {
713  // Allreduce minimum
714  using combiner_t =
717 
718  return RAJA::cuda::impl::warp_allreduce<combiner_t, element_type>(m_value);
719  }
720 
725  RAJA_INLINE
726 
728  element_type min_n(int N) const
729  {
730  // Allreduce minimum
731  using combiner_t =
734 
736  auto lane = get_lane();
737  auto value = lane < N ? m_value : ident;
738  return RAJA::cuda::impl::warp_allreduce<combiner_t, element_type>(value);
739  }
740 
745  RAJA_INLINE
746 
748  self_type vmin(self_type a) const
749  {
750  return self_type {RAJA::min<element_type>(m_value, a.m_value)};
751  }
752 
760  RAJA_INLINE
761 
763  static int_vector_type s_segmented_offsets(camp::idx_t segbits,
764  camp::idx_t stride_inner,
765  camp::idx_t stride_outer)
766  {
767  int_vector_type result;
768 
769  auto lane = get_lane();
770 
771  // compute segment and segment_size
772  auto seg = lane >> segbits;
773  auto i = lane & ((1 << segbits) - 1);
774 
775  result.get_raw_value() = seg * stride_outer + i * stride_inner;
776 
777  return result;
778  }
779 
813  RAJA_INLINE
814 
816  self_type segmented_sum_inner(camp::idx_t segbits,
817  camp::idx_t output_segment) const
818  {
819 
820  // First: tree reduce values within each segment
821  element_type x = m_value;
822  RAJA_UNROLL
823  for (int delta = 1; delta < 1 << segbits; delta = delta << 1)
824  {
825 
826  // tree shuffle
827  element_type y = __shfl_sync(0xffffffff, x, get_lane() + delta);
828 
829  // reduce
830  x += y;
831  }
832 
833  // Second: send result to output segment lanes
834  self_type result;
835  result.get_raw_value() = __shfl_sync(0xffffffff, x, get_lane() << segbits);
836 
837  // Third: mask off everything but output_segment
838  // this is because all output segments are valid at this point
839  static constexpr int log2_warp_size = RAJA::log2(RAJA_CUDA_WARPSIZE);
840  int our_output_segment = get_lane() >> (log2_warp_size - segbits);
841  bool in_output_segment = our_output_segment == output_segment;
842  if (!in_output_segment)
843  {
844  result.get_raw_value() = 0;
845  }
846 
847  return result;
848  }
849 
881  RAJA_INLINE
882 
884  self_type segmented_sum_outer(camp::idx_t segbits,
885  camp::idx_t output_segment) const
886  {
887 
888  // First: tree reduce values within each segment
889  element_type x = m_value;
890  static constexpr int log2_warp_size = RAJA::log2(RAJA_CUDA_WARPSIZE);
891  RAJA_UNROLL
892  for (int i = 0; i < log2_warp_size - segbits; ++i)
893  {
894 
895  // tree shuffle
896  int delta = s_num_elem >> (i + 1);
897  element_type y = __shfl_sync(0xffffffff, x, get_lane() + delta);
898 
899  // reduce
900  x += y;
901  }
902 
903  // Second: send result to output segment lanes
904  self_type result;
905  int get_from = get_lane() & ((1 << segbits) - 1);
906  result.get_raw_value() = __shfl_sync(0xffffffff, x, get_from);
907 
908  int mask = (get_lane() >> segbits) == output_segment;
909 
910 
911  // Third: mask off everything but output_segment
912  if (!mask)
913  {
914  result.get_raw_value() = 0;
915  }
916 
917  return result;
918  }
919 
920  RAJA_INLINE
921 
923  self_type segmented_divide_nm(self_type den,
924  camp::idx_t segbits,
925  camp::idx_t num_inner,
926  camp::idx_t num_outer) const
927  {
928  self_type result;
929 
930  auto lane = get_lane();
931 
932  // compute segment and segment_size
933  auto seg = lane >> segbits;
934  auto i = lane & ((1 << segbits) - 1);
935 
936  if (seg >= num_outer || i >= num_inner)
937  {
938  // nop
939  }
940  else
941  {
942  result.get_raw_value() = m_value / den.get_raw_value();
943  }
944 
945  return result;
946  }
947 
995  RAJA_INLINE
996 
998  self_type segmented_broadcast_inner(camp::idx_t segbits,
999  camp::idx_t input_segment) const
1000  {
1001  self_type result;
1002 
1003  camp::idx_t mask = (1 << segbits) - 1;
1004  camp::idx_t offset = input_segment << segbits;
1005 
1006 
1007  camp::idx_t i = (get_lane() & mask) + offset;
1008 
1009  result.get_raw_value() = __shfl_sync(0xffffffff, m_value, i);
1010 
1011 
1012  return result;
1013  }
1014 
1051  RAJA_INLINE
1052 
1053  RAJA_DEVICE
1054  self_type segmented_broadcast_outer(camp::idx_t segbits,
1055  camp::idx_t input_segment) const
1056  {
1057  self_type result;
1058 
1059  camp::idx_t offset = input_segment * (self_type::s_num_elem >> segbits);
1060 
1061  camp::idx_t i = (get_lane() >> segbits) + offset;
1062 
1063  result.get_raw_value() = __shfl_sync(0xffffffff, m_value, i);
1064 
1065  return result;
1066  }
1067 };
1068 
1069 
1070 } // namespace expt
1071 
1072 } // namespace RAJA
1073 
1074 
1075 #endif // CUDA
1076 
1077 #endif // Guard
Header file for RAJA operator definitions.
RAJA header file defining SIMD/SIMT register operations.
Header file containing RAJA intrinsics templates for CUDA execution.
Header file for common RAJA internal macro definitions.
#define RAJA_DEVICE
Definition: macros.hpp:66
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE T log2(T n) noexcept
evaluate log base 2 of n
Definition: math.hpp:40
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155
Definition: Operators.hpp:580
Definition: Operators.hpp:559
Definition: reduce.hpp:70