RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
hip_wave.hpp
Go to the documentation of this file.
1 
12 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
13 // Copyright (c) Lawrence Livermore National Security, LLC and other
14 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
15 // files for dates and other details. No copyright assignment is required
16 // to contribute to RAJA.
17 //
18 // SPDX-License-Identifier: (BSD-3-Clause)
19 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
20 
21 #ifndef RAJA_policy_tensor_arch_hip_hip_wave_register_HPP
22 #define RAJA_policy_tensor_arch_hip_hip_wave_register_HPP
23 
24 #include "RAJA/config.hpp"
25 
26 #if defined(RAJA_HIP_ACTIVE)
27 
28 #include "RAJA/util/macros.hpp"
30 #include "RAJA/util/macros.hpp"
31 #include "RAJA/util/Operators.hpp"
32 
34 
35 namespace RAJA
36 {
37 namespace expt
38 {
39 
40 
41 template<typename ELEMENT_TYPE>
42 class Register<ELEMENT_TYPE, hip_wave_register>
43  : public internal::expt::RegisterBase<
44  Register<ELEMENT_TYPE, hip_wave_register>>
45 {
46 public:
47  using base_type =
48  internal::expt::RegisterBase<Register<ELEMENT_TYPE, hip_wave_register>>;
49 
50  using register_policy = hip_wave_register;
51  using self_type = Register<ELEMENT_TYPE, hip_wave_register>;
52  using element_type = ELEMENT_TYPE;
53  using register_type = ELEMENT_TYPE;
54 
55  using int_vector_type = Register<int64_t, hip_wave_register>;
56 
57 
58 private:
59  element_type m_value;
60 
61 public:
62  static constexpr int s_num_elem = RAJA_HIP_WAVESIZE;
63 
67  RAJA_INLINE
68 
70  constexpr Register() : base_type(), m_value(0) {}
71 
75  RAJA_INLINE
76 
78  constexpr Register(element_type c) : base_type(), m_value(c) {}
79 
83  RAJA_INLINE
84 
86  constexpr Register(self_type const& c) : base_type(), m_value(c.m_value) {}
87 
91  RAJA_INLINE
92 
94  self_type& operator=(self_type const& c)
95  {
96  m_value = c.m_value;
97  return *this;
98  }
99 
100  RAJA_INLINE
101 
103  self_type& operator=(element_type c)
104  {
105  m_value = c;
106  return *this;
107  }
108 
112  RAJA_INLINE
113 
115  constexpr static int get_lane() { return threadIdx.x; }
116 
118 
119  RAJA_INLINE
120  constexpr element_type const& get_raw_value() const { return m_value; }
121 
123 
124  RAJA_INLINE
125  element_type& get_raw_value() { return m_value; }
126 
128 
129  RAJA_INLINE
130  static constexpr bool is_root() { return get_lane() == 0; }
131 
136  RAJA_INLINE
137 
139  self_type& load_packed(element_type const* ptr)
140  {
141 
142  auto lane = get_lane();
143 
144  m_value = ptr[lane];
145 
146  return *this;
147  }
148 
154  RAJA_INLINE
155 
157  self_type& load_packed_n(element_type const* ptr, int N)
158  {
159  auto lane = get_lane();
160  if (lane < N)
161  {
162  m_value = ptr[lane];
163  }
164  else
165  {
166  m_value = element_type(0);
167  }
168  return *this;
169  }
170 
175  RAJA_INLINE
176 
178  self_type& load_strided(element_type const* ptr, int stride)
179  {
180 
181  auto lane = get_lane();
182 
183  m_value = ptr[stride * lane];
184 
185  return *this;
186  }
187 
193  RAJA_INLINE
194 
196  self_type& load_strided_n(element_type const* ptr, int stride, int N)
197  {
198  auto lane = get_lane();
199 
200  if (lane < N)
201  {
202  m_value = ptr[stride * lane];
203  }
204  else
205  {
206  m_value = element_type(0);
207  }
208  return *this;
209  }
210 
220  RAJA_INLINE
221 
223  self_type& gather(element_type const* ptr, int_vector_type offsets)
224  {
225 
226  m_value = ptr[offsets.get_raw_value()];
227 
228  return *this;
229  }
230 
240  RAJA_INLINE
241 
243  self_type& gather_n(element_type const* ptr,
244  int_vector_type offsets,
245  camp::idx_t N)
246  {
247  if (get_lane() < N)
248  {
249  m_value = ptr[offsets.get_raw_value()];
250  }
251  else
252  {
253  m_value = element_type(0);
254  }
255 
256  return *this;
257  }
258 
269 
270  RAJA_INLINE
271  self_type& segmented_load(element_type const* ptr,
272  camp::idx_t segbits,
273  camp::idx_t stride_inner,
274  camp::idx_t stride_outer)
275  {
276  auto lane = get_lane();
277 
278  // compute segment and segment_size
279  auto seg = lane >> segbits;
280  auto i = lane & ((1 << segbits) - 1);
281 
282  m_value = ptr[seg * stride_outer + i * stride_inner];
283 
284  return *this;
285  }
286 
295 
296  RAJA_INLINE
297  self_type& segmented_load_nm(element_type const* ptr,
298  camp::idx_t segbits,
299  camp::idx_t stride_inner,
300  camp::idx_t stride_outer,
301  camp::idx_t num_inner,
302  camp::idx_t num_outer)
303  {
304  auto lane = get_lane();
305 
306  // compute segment and segment_size
307  auto seg = lane >> segbits;
308  auto i = lane & ((1 << segbits) - 1);
309 
310  if (seg >= num_outer || i >= num_inner)
311  {
312  m_value = element_type(0);
313  }
314  else
315  {
316  m_value = ptr[seg * stride_outer + i * stride_inner];
317  }
318 
319  return *this;
320  }
321 
326  RAJA_INLINE
327 
329  self_type const& store_packed(element_type* ptr) const
330  {
331 
332  auto lane = get_lane();
333 
334  ptr[lane] = m_value;
335 
336  return *this;
337  }
338 
343  RAJA_INLINE
344 
346  self_type const& store_packed_n(element_type* ptr, int N) const
347  {
348 
349  auto lane = get_lane();
350 
351  if (lane < N)
352  {
353  ptr[lane] = m_value;
354  }
355  return *this;
356  }
357 
362  RAJA_INLINE
363 
365  self_type const& store_strided(element_type* ptr, int stride) const
366  {
367 
368  auto lane = get_lane();
369 
370  ptr[lane * stride] = m_value;
371 
372  return *this;
373  }
374 
379  RAJA_INLINE
380 
382  self_type const& store_strided_n(element_type* ptr, int stride, int N) const
383  {
384 
385  auto lane = get_lane();
386 
387  if (lane < N)
388  {
389  ptr[lane * stride] = m_value;
390  }
391  return *this;
392  }
393 
403  template<typename T2>
404  RAJA_DEVICE RAJA_INLINE self_type const& scatter(element_type* ptr,
405  T2 const& offsets) const
406  {
407 
408  ptr[offsets.get_raw_value()] = m_value;
409 
410 
411  return *this;
412  }
413 
423  template<typename T2>
424  RAJA_DEVICE RAJA_INLINE self_type const& scatter_n(element_type* ptr,
425  T2 const& offsets,
426  camp::idx_t N) const
427  {
428  if (get_lane() < N)
429  {
430  ptr[offsets.get_raw_value()] = m_value;
431  }
432 
433  return *this;
434  }
435 
442 
443  RAJA_INLINE
444  self_type const& segmented_store(element_type* ptr,
445  camp::idx_t segbits,
446  camp::idx_t stride_inner,
447  camp::idx_t stride_outer) const
448  {
449  auto lane = get_lane();
450 
451  // compute segment and segment_size
452  auto seg = lane >> segbits;
453  auto i = lane & ((1 << segbits) - 1);
454 
455  ptr[seg * stride_outer + i * stride_inner] = m_value;
456 
457  return *this;
458  }
459 
466 
467  RAJA_INLINE
468  self_type const& segmented_store_nm(element_type* ptr,
469  camp::idx_t segbits,
470  camp::idx_t stride_inner,
471  camp::idx_t stride_outer,
472  camp::idx_t num_inner,
473  camp::idx_t num_outer) const
474  {
475  auto lane = get_lane();
476 
477  // compute segment and segment_size
478  auto seg = lane >> segbits;
479  auto i = lane & ((1 << segbits) - 1);
480 
481  if (seg >= num_outer || i >= num_inner)
482  {
483  // nop
484  }
485  else
486  {
487  ptr[seg * stride_outer + i * stride_inner] = m_value;
488  }
489 
490  return *this;
491  }
492 
498  constexpr RAJA_INLINE RAJA_DEVICE element_type get(int i) const
499  {
500  return hip::impl::shfl_sync(m_value, i);
501  }
502 
508  RAJA_INLINE
509 
511  self_type& set(element_type value, int i)
512  {
513  auto lane = get_lane();
514  if (lane == i)
515  {
516  m_value = value;
517  }
518  return *this;
519  }
520 
522 
523  RAJA_INLINE
524  self_type& broadcast(element_type const& a)
525  {
526  m_value = a;
527  return *this;
528  }
529 
534 
535  RAJA_INLINE
536  self_type get_and_broadcast(int i) const
537  {
538  self_type x;
539  x.m_value = hip::impl::shfl_sync(m_value, i);
540  return x;
541  }
542 
544 
545  RAJA_INLINE
546  self_type& copy(self_type const& src)
547  {
548  m_value = src.m_value;
549  return *this;
550  }
551 
553 
554  RAJA_INLINE
555  self_type add(self_type const& b) const
556  {
557  return self_type(m_value + b.m_value);
558  }
559 
561 
562  RAJA_INLINE
563  self_type subtract(self_type const& b) const
564  {
565  return self_type(m_value - b.m_value);
566  }
567 
569 
570  RAJA_INLINE
571  self_type multiply(self_type const& b) const
572  {
573  return self_type(m_value * b.m_value);
574  }
575 
577 
578  RAJA_INLINE
579  self_type divide(self_type const& b) const
580  {
581  return self_type(m_value / b.m_value);
582  }
583 
585 
586  RAJA_INLINE
587  self_type divide_n(self_type const& b, int N) const
588  {
589  return get_lane() < N ? self_type(m_value / b.m_value)
590  : self_type(element_type(0));
591  }
592 
596  template<typename RETURN_TYPE = self_type>
597  RAJA_DEVICE RAJA_INLINE
598  typename std::enable_if<!std::numeric_limits<element_type>::is_integer,
599  RETURN_TYPE>::type
600  multiply_add(self_type const& b, self_type const& c) const
601  {
602  return self_type(fma(m_value, b.m_value, c.m_value));
603  }
604 
608  template<typename RETURN_TYPE = self_type>
609  RAJA_DEVICE RAJA_INLINE
610  typename std::enable_if<std::numeric_limits<element_type>::is_integer,
611  RETURN_TYPE>::type
612  multiply_add(self_type const& b, self_type const& c) const
613  {
614  return self_type(m_value * b.m_value + c.m_value);
615  }
616 
620  template<typename RETURN_TYPE = self_type>
621  RAJA_DEVICE RAJA_INLINE
622  typename std::enable_if<!std::numeric_limits<element_type>::is_integer,
623  RETURN_TYPE>::type
624  multiply_subtract(self_type const& b, self_type const& c) const
625  {
626  return self_type(fma(m_value, b.m_value, -c.m_value));
627  }
628 
632  template<typename RETURN_TYPE = self_type>
633  RAJA_DEVICE RAJA_INLINE
634  typename std::enable_if<std::numeric_limits<element_type>::is_integer,
635  RETURN_TYPE>::type
636  multiply_subtract(self_type const& b, self_type const& c) const
637  {
638  return self_type(m_value * b.m_value - c.m_value);
639  }
640 
645  RAJA_INLINE
646 
648  element_type sum() const
649  {
650  // Allreduce sum
651  using combiner_t =
653 
654  return RAJA::hip::impl::warp_allreduce<combiner_t, element_type>(m_value);
655  }
656 
661  RAJA_INLINE
662 
664  element_type max() const
665  {
666  // Allreduce maximum
667  using combiner_t =
670 
671  return RAJA::hip::impl::warp_allreduce<combiner_t, element_type>(m_value);
672  }
673 
678  RAJA_INLINE
679 
681  element_type max_n(int N) const
682  {
683  // Allreduce maximum
684  using combiner_t =
687 
689  auto lane = get_lane();
690  auto value = lane < N ? m_value : ident;
691  return RAJA::hip::impl::warp_allreduce<combiner_t, element_type>(value);
692  }
693 
698  RAJA_INLINE
699 
701  self_type vmax(self_type a) const
702  {
703  return self_type {RAJA::max<element_type>(m_value, a.m_value)};
704  }
705 
710  RAJA_INLINE
711 
713  element_type min() const
714  {
715  // Allreduce minimum
716  using combiner_t =
719 
720  return RAJA::hip::impl::warp_allreduce<combiner_t, element_type>(m_value);
721  }
722 
727  RAJA_INLINE
728 
730  element_type min_n(int N) const
731  {
732  // Allreduce minimum
733  using combiner_t =
736 
738  auto lane = get_lane();
739  auto value = lane < N ? m_value : ident;
740  return RAJA::hip::impl::warp_allreduce<combiner_t, element_type>(value);
741  }
742 
747  RAJA_INLINE
748 
750  self_type vmin(self_type a) const
751  {
752  return self_type {RAJA::min<element_type>(m_value, a.m_value)};
753  }
754 
762  RAJA_INLINE
763 
765  static int_vector_type s_segmented_offsets(camp::idx_t segbits,
766  camp::idx_t stride_inner,
767  camp::idx_t stride_outer)
768  {
769  int_vector_type result;
770 
771  auto lane = get_lane();
772 
773  // compute segment and segment_size
774  auto seg = lane >> segbits;
775  auto i = lane & ((1 << segbits) - 1);
776 
777  result.get_raw_value() = seg * stride_outer + i * stride_inner;
778 
779  return result;
780  }
781 
815  RAJA_INLINE
816 
818  self_type segmented_sum_inner(camp::idx_t segbits,
819  camp::idx_t output_segment) const
820  {
821 
822  // First: tree reduce values within each segment
823  element_type x = m_value;
824  RAJA_UNROLL
825  for (int delta = 1; delta < 1 << segbits; delta = delta << 1)
826  {
827 
828  // tree shuffle
829  element_type y = hip::impl::shfl_sync(x, get_lane() + delta);
830 
831  // reduce
832  x += y;
833  }
834 
835  // Second: send result to output segment lanes
836  self_type result;
837  result.get_raw_value() = hip::impl::shfl_sync(x, get_lane() << segbits);
838 
839  // Third: mask off everything but output_segment
840  // this is because all output segments are valid at this point
841  static constexpr int log2_warp_size = RAJA::log2(RAJA_HIP_WAVESIZE);
842  int our_output_segment = get_lane() >> (log2_warp_size - segbits);
843  bool in_output_segment = our_output_segment == output_segment;
844  if (!in_output_segment)
845  {
846  result.get_raw_value() = 0;
847  }
848 
849  return result;
850  }
851 
883  RAJA_INLINE
884 
886  self_type segmented_sum_outer(camp::idx_t segbits,
887  camp::idx_t output_segment) const
888  {
889 
890  // First: tree reduce values within each segment
891  element_type x = m_value;
892  static constexpr int log2_warp_size = RAJA::log2(RAJA_HIP_WAVESIZE);
893  RAJA_UNROLL
894  for (int i = 0; i < log2_warp_size - segbits; ++i)
895  {
896 
897  // tree shuffle
898  int delta = s_num_elem >> (i + 1);
899  element_type y = hip::impl::shfl_sync(x, get_lane() + delta);
900 
901  // reduce
902  x += y;
903  }
904 
905  // Second: send result to output segment lanes
906  self_type result;
907  int get_from = get_lane() & ((1 << segbits) - 1);
908  result.get_raw_value() = hip::impl::shfl_sync(x, get_from);
909 
910  int mask = (get_lane() >> segbits) == output_segment;
911 
912 
913  // Third: mask off everything but output_segment
914  if (!mask)
915  {
916  result.get_raw_value() = 0;
917  }
918 
919  return result;
920  }
921 
922  RAJA_INLINE
923 
925  self_type segmented_divide_nm(self_type den,
926  camp::idx_t segbits,
927  camp::idx_t num_inner,
928  camp::idx_t num_outer) const
929  {
930  self_type result;
931 
932  auto lane = get_lane();
933 
934  // compute segment and segment_size
935  auto seg = lane >> segbits;
936  auto i = lane & ((1 << segbits) - 1);
937 
938  if (seg >= num_outer || i >= num_inner)
939  {
940  // nop
941  }
942  else
943  {
944  result.get_raw_value() = m_value / den.get_raw_value();
945  }
946 
947  return result;
948  }
949 
997  RAJA_INLINE
998 
1000  self_type segmented_broadcast_inner(camp::idx_t segbits,
1001  camp::idx_t input_segment) const
1002  {
1003  self_type result;
1004 
1005  camp::idx_t mask = (1 << segbits) - 1;
1006  camp::idx_t offset = input_segment << segbits;
1007 
1008 
1009  camp::idx_t i = (get_lane() & mask) + offset;
1010 
1011  result.get_raw_value() = hip::impl::shfl_sync(m_value, i);
1012 
1013 
1014  return result;
1015  }
1016 
1053  RAJA_INLINE
1054 
1055  RAJA_DEVICE
1056  self_type segmented_broadcast_outer(camp::idx_t segbits,
1057  camp::idx_t input_segment) const
1058  {
1059  self_type result;
1060 
1061  camp::idx_t offset = input_segment * (self_type::s_num_elem >> segbits);
1062 
1063  camp::idx_t i = (get_lane() >> segbits) + offset;
1064 
1065  result.get_raw_value() = hip::impl::shfl_sync(m_value, i);
1066 
1067  return result;
1068  }
1069 };
1070 
1071 
1072 } // namespace expt
1073 
1074 } // namespace RAJA
1075 
1076 
1077 #endif // HIP
1078 
1079 #endif // Guard
Header file for RAJA operator definitions.
RAJA header file defining SIMD/SIMT register operations.
Header file containing RAJA intrinsics templates for HIP execution.
Header file for common RAJA internal macro definitions.
#define RAJA_DEVICE
Definition: macros.hpp:66
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE T log2(T n) noexcept
evaluate log base 2 of n
Definition: math.hpp:40
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155
Definition: Operators.hpp:580
Definition: Operators.hpp:559
Definition: reduce.hpp:70