21 #ifndef RAJA_policy_tensor_arch_hip_hip_wave_register_HPP
22 #define RAJA_policy_tensor_arch_hip_hip_wave_register_HPP
24 #include "RAJA/config.hpp"
26 #if defined(RAJA_HIP_ACTIVE)
// Partial specialization of Register for the hip_wave_register policy.
// The class derives from internal::expt::RegisterBase parameterized on
// itself (CRTP).
41 template<
typename ELEMENT_TYPE>
42 class Register<ELEMENT_TYPE, hip_wave_register>
43 :
public internal::expt::RegisterBase<
44 Register<ELEMENT_TYPE, hip_wave_register>>
// NOTE(review): the next line is the tail of a using-declaration whose first
// line is missing from this listing (presumably `using base_type =`).
48 internal::expt::RegisterBase<Register<ELEMENT_TYPE, hip_wave_register>>;
50 using register_policy = hip_wave_register;
51 using self_type = Register<ELEMENT_TYPE, hip_wave_register>;
52 using element_type = ELEMENT_TYPE;
// The per-object storage type is the element type itself.
53 using register_type = ELEMENT_TYPE;
// Offset vectors used by gather() are int64_t registers of the same policy.
55 using int_vector_type = Register<int64_t, hip_wave_register>;
// Number of elements in a register: the HIP wave size macro.
62 static constexpr
int s_num_elem = RAJA_HIP_WAVESIZE;
// Default constructor: initializes the stored value m_value to 0.
70 constexpr Register() : base_type(), m_value(0) {}
// Constructs from a scalar: stores c in m_value. Not marked explicit, so an
// element_type converts implicitly to a Register.
78 constexpr Register(element_type c) : base_type(), m_value(c) {}
// Copy constructor: copies m_value from c.
86 constexpr Register(self_type
const& c) : base_type(), m_value(c.m_value) {}
// Copy assignment from another register; returns self_type&.
// NOTE(review): the body is not present in this listing.
94 self_type& operator=(self_type
const& c)
// Assignment from a scalar element_type; returns self_type&.
// NOTE(review): the body is not present in this listing.
103 self_type& operator=(element_type c)
// Returns this thread's lane index, taken directly from threadIdx.x (only the
// x component of the thread index is consulted).
// NOTE(review): declared constexpr although threadIdx.x is a runtime value;
// confirm this is intentional for the HIP toolchain.
115 constexpr
static int get_lane() {
return threadIdx.x; }
// Read-only access to the stored value m_value.
120 constexpr element_type
const& get_raw_value()
const {
return m_value; }
// Mutable access to the stored value m_value.
125 element_type& get_raw_value() {
return m_value; }
// True when the calling thread's lane index (get_lane()) is 0.
130 static constexpr
bool is_root() {
return get_lane() == 0; }
// Loads from ptr using the calling thread's lane index.
// NOTE(review): the load statement itself is missing from this listing;
// presumably m_value = ptr[lane] - confirm against the full source.
139 self_type& load_packed(element_type
const* ptr)
142 auto lane = get_lane();
// Partial packed load with an element count N.
// NOTE(review): the condition and the load statement are not shown in this
// listing; only the zero-fill assignment is visible. Presumably lanes outside
// [0, N) receive element_type(0) - confirm.
157 self_type& load_packed_n(element_type
const* ptr,
int N)
159 auto lane = get_lane();
166 m_value = element_type(0);
// Strided load: each thread reads ptr[stride * lane] into m_value, where lane
// is get_lane().
178 self_type& load_strided(element_type
const* ptr,
int stride)
181 auto lane = get_lane();
183 m_value = ptr[stride * lane];
// Partial strided load with an element count N: one path reads
// ptr[stride * lane], the other stores element_type(0).
// NOTE(review): the branch condition is not shown in this listing; presumably
// lane < N selects the load - confirm.
196 self_type& load_strided_n(element_type
const* ptr,
int stride,
int N)
198 auto lane = get_lane();
202 m_value = ptr[stride * lane];
206 m_value = element_type(0);
// Gather: each thread reads ptr at the index held in its own copy of offsets
// (offsets.get_raw_value()) and stores the result in m_value.
223 self_type& gather(element_type
const* ptr, int_vector_type offsets)
226 m_value = ptr[offsets.get_raw_value()];
// Partial gather: one path reads ptr[offsets.get_raw_value()], the other
// stores element_type(0).
// NOTE(review): the remainder of the parameter list and the branch condition
// are not shown in this listing; presumably a count N bounds the active lanes.
243 self_type& gather_n(element_type
const* ptr,
244 int_vector_type offsets,
249 m_value = ptr[offsets.get_raw_value()];
253 m_value = element_type(0);
// Segmented load. The lane index is split into a segment number
// (lane >> segbits) and a position within the segment (low segbits bits of
// lane); the thread then reads ptr[seg * stride_outer + i * stride_inner].
// NOTE(review): the declaration of segbits (presumably a parameter) is on a
// line not shown in this listing.
271 self_type& segmented_load(element_type
const* ptr,
273 camp::idx_t stride_inner,
274 camp::idx_t stride_outer)
276 auto lane = get_lane();
279 auto seg = lane >> segbits;
280 auto i = lane & ((1 << segbits) - 1);
282 m_value = ptr[seg * stride_outer + i * stride_inner];
// Bounded segmented load. seg = lane >> segbits and i = low segbits bits of
// lane; the bounds test is (seg >= num_outer || i >= num_inner).
// NOTE(review): braces/else are not shown in this listing; presumably
// out-of-bounds lanes get element_type(0) and the rest read
// ptr[seg * stride_outer + i * stride_inner]. The declaration of segbits is
// also on a line not shown.
297 self_type& segmented_load_nm(element_type
const* ptr,
299 camp::idx_t stride_inner,
300 camp::idx_t stride_outer,
301 camp::idx_t num_inner,
302 camp::idx_t num_outer)
304 auto lane = get_lane();
307 auto seg = lane >> segbits;
308 auto i = lane & ((1 << segbits) - 1);
310 if (seg >= num_outer || i >= num_inner)
312 m_value = element_type(0);
316 m_value = ptr[seg * stride_outer + i * stride_inner];
// Packed store to ptr; const member returning self_type const&.
// NOTE(review): the store statement is missing from this listing; presumably
// ptr[lane] = m_value - confirm.
329 self_type
const& store_packed(element_type* ptr)
const
332 auto lane = get_lane();
// Partial packed store with an element count N.
// NOTE(review): the condition and the store statement are missing from this
// listing; presumably only lanes below N write - confirm.
346 self_type
const& store_packed_n(element_type* ptr,
int N)
const
349 auto lane = get_lane();
// Strided store: each thread writes m_value to ptr[lane * stride], where lane
// is get_lane().
365 self_type
const& store_strided(element_type* ptr,
int stride)
const
368 auto lane = get_lane();
370 ptr[lane * stride] = m_value;
// Partial strided store with an element count N: writes m_value to
// ptr[lane * stride].
// NOTE(review): the guarding condition is not shown in this listing;
// presumably the write happens only when lane < N - confirm.
382 self_type
const& store_strided_n(element_type* ptr,
int stride,
int N)
const
385 auto lane = get_lane();
389 ptr[lane * stride] = m_value;
// Scatter: each thread writes m_value to ptr at the index returned by
// offsets.get_raw_value(). T2 can be any type providing get_raw_value() whose
// result is usable as an array index.
403 template<
typename T2>
404 RAJA_DEVICE RAJA_INLINE self_type
const& scatter(element_type* ptr,
405 T2
const& offsets)
const
408 ptr[offsets.get_raw_value()] = m_value;
// Partial scatter: writes m_value to ptr[offsets.get_raw_value()].
// NOTE(review): the rest of the parameter list and the guarding condition are
// not shown in this listing; presumably a count N bounds the writing lanes.
423 template<
typename T2>
424 RAJA_DEVICE RAJA_INLINE self_type
const& scatter_n(element_type* ptr,
430 ptr[offsets.get_raw_value()] = m_value;
// Segmented store. The lane index is split into seg = lane >> segbits and
// i = low segbits bits of lane; the thread writes m_value to
// ptr[seg * stride_outer + i * stride_inner].
// NOTE(review): the declaration of segbits (presumably a parameter) is on a
// line not shown in this listing.
444 self_type
const& segmented_store(element_type* ptr,
446 camp::idx_t stride_inner,
447 camp::idx_t stride_outer)
const
449 auto lane = get_lane();
452 auto seg = lane >> segbits;
453 auto i = lane & ((1 << segbits) - 1);
455 ptr[seg * stride_outer + i * stride_inner] = m_value;
// Bounded segmented store. seg = lane >> segbits and i = low segbits bits of
// lane; the bounds test is (seg >= num_outer || i >= num_inner).
// NOTE(review): the branch bodies/braces are not shown in this listing;
// presumably out-of-bounds lanes skip the write and the rest store m_value at
// ptr[seg * stride_outer + i * stride_inner]. The declaration of segbits is
// also on a line not shown.
468 self_type
const& segmented_store_nm(element_type* ptr,
470 camp::idx_t stride_inner,
471 camp::idx_t stride_outer,
472 camp::idx_t num_inner,
473 camp::idx_t num_outer)
const
475 auto lane = get_lane();
478 auto seg = lane >> segbits;
479 auto i = lane & ((1 << segbits) - 1);
481 if (seg >= num_outer || i >= num_inner)
487 ptr[seg * stride_outer + i * stride_inner] = m_value;
// NOTE(review): the signature of this function is missing from this listing
// (presumably an element accessor taking an index i). It returns the result
// of hip::impl::shfl_sync(m_value, i); presumably that is lane i's value.
500 return hip::impl::shfl_sync(m_value, i);
// Sets an element identified by index i to value.
// NOTE(review): the assignment is missing from this listing; presumably only
// the thread whose lane equals i updates m_value - confirm.
511 self_type& set(element_type value,
int i)
513 auto lane = get_lane();
// Takes a scalar a and returns self_type&.
// NOTE(review): the body is missing from this listing; presumably m_value is
// set to a on every lane - confirm.
524 self_type& broadcast(element_type
const& a)
// Builds a register x whose m_value is hip::impl::shfl_sync(m_value, i).
// NOTE(review): the declaration and return of x are not shown in this
// listing; presumably shfl_sync fetches lane i's value, so every lane of the
// result holds element i - confirm.
536 self_type get_and_broadcast(
int i)
const
539 x.m_value = hip::impl::shfl_sync(m_value, i);
// Copies src.m_value into this register's m_value.
546 self_type& copy(self_type
const& src)
548 m_value = src.m_value;
// Returns a new register holding m_value + b.m_value; *this is unchanged.
555 self_type add(self_type
const& b)
const
557 return self_type(m_value + b.m_value);
// Returns a new register holding m_value - b.m_value; *this is unchanged.
563 self_type subtract(self_type
const& b)
const
565 return self_type(m_value - b.m_value);
// Returns a new register holding m_value * b.m_value; *this is unchanged.
571 self_type multiply(self_type
const& b)
const
573 return self_type(m_value * b.m_value);
// Returns a new register holding m_value / b.m_value; *this is unchanged.
// No check is made here for a zero divisor.
579 self_type divide(self_type
const& b)
const
581 return self_type(m_value / b.m_value);
// Partial divide: threads with get_lane() < N return m_value / b.m_value;
// all other threads return element_type(0). The division is evaluated only
// on the selected branch of the conditional expression.
587 self_type divide_n(self_type
const& b,
int N)
const
589 return get_lane() < N ? self_type(m_value / b.m_value)
590 : self_type(element_type(0));
// Multiply-add overload enabled when element_type is NOT an integer type
// (std::numeric_limits<element_type>::is_integer is false). It computes
// m_value * b.m_value + c.m_value through a single fma() call.
// NOTE(review): the tail of the enable_if return type is on a line not shown
// in this listing.
596 template<
typename RETURN_TYPE = self_type>
598 typename std::enable_if<!std::numeric_limits<element_type>::is_integer,
600 multiply_add(self_type
const& b, self_type
const& c)
const
602 return self_type(fma(m_value, b.m_value, c.m_value));
// Multiply-add overload enabled when element_type IS an integer type. It
// computes m_value * b.m_value + c.m_value with ordinary integer arithmetic.
// NOTE(review): the tail of the enable_if return type is on a line not shown
// in this listing.
608 template<
typename RETURN_TYPE = self_type>
610 typename std::enable_if<std::numeric_limits<element_type>::is_integer,
612 multiply_add(self_type
const& b, self_type
const& c)
const
614 return self_type(m_value * b.m_value + c.m_value);
// Multiply-subtract overload enabled when element_type is NOT an integer
// type. It computes m_value * b.m_value - c.m_value as fma(m_value, b, -c).
// NOTE(review): the tail of the enable_if return type is on a line not shown
// in this listing.
620 template<
typename RETURN_TYPE = self_type>
622 typename std::enable_if<!std::numeric_limits<element_type>::is_integer,
624 multiply_subtract(self_type
const& b, self_type
const& c)
const
626 return self_type(fma(m_value, b.m_value, -c.m_value));
// Multiply-subtract overload enabled when element_type IS an integer type.
// It computes m_value * b.m_value - c.m_value with ordinary integer
// arithmetic.
// NOTE(review): the tail of the enable_if return type is on a line not shown
// in this listing.
632 template<
typename RETURN_TYPE = self_type>
634 typename std::enable_if<std::numeric_limits<element_type>::is_integer,
636 multiply_subtract(self_type
const& b, self_type
const& c)
const
638 return self_type(m_value * b.m_value - c.m_value);
// Returns the result of RAJA::hip::impl::warp_allreduce applied to this
// thread's m_value with combiner_t.
// NOTE(review): the combiner_t alias is declared on a line not shown in this
// listing; presumably a sum combiner - confirm.
648 element_type
sum()
const
654 return RAJA::hip::impl::warp_allreduce<combiner_t, element_type>(m_value);
// Returns the result of RAJA::hip::impl::warp_allreduce applied to this
// thread's m_value with combiner_t.
// NOTE(review): the combiner_t alias is declared on a line not shown in this
// listing; presumably a maximum combiner - confirm.
664 element_type
max()
const
671 return RAJA::hip::impl::warp_allreduce<combiner_t, element_type>(m_value);
// Partial reduction: threads with lane < N contribute m_value, all others
// contribute ident; the chosen value is passed to
// RAJA::hip::impl::warp_allreduce with combiner_t.
// NOTE(review): combiner_t and ident are declared on lines not shown in this
// listing; presumably a maximum combiner and its identity value - confirm.
681 element_type max_n(
int N)
const
689 auto lane = get_lane();
690 auto value = lane < N ? m_value : ident;
691 return RAJA::hip::impl::warp_allreduce<combiner_t, element_type>(value);
// Returns a register holding RAJA::max<element_type>(m_value, a.m_value),
// computed independently by each thread on its own values.
701 self_type vmax(self_type a)
const
703 return self_type {RAJA::max<element_type>(m_value, a.m_value)};
// Returns the result of RAJA::hip::impl::warp_allreduce applied to this
// thread's m_value with combiner_t.
// NOTE(review): the combiner_t alias is declared on a line not shown in this
// listing; presumably a minimum combiner - confirm.
713 element_type
min()
const
720 return RAJA::hip::impl::warp_allreduce<combiner_t, element_type>(m_value);
// Partial reduction: threads with lane < N contribute m_value, all others
// contribute ident; the chosen value is passed to
// RAJA::hip::impl::warp_allreduce with combiner_t.
// NOTE(review): combiner_t and ident are declared on lines not shown in this
// listing; presumably a minimum combiner and its identity value - confirm.
730 element_type min_n(
int N)
const
738 auto lane = get_lane();
739 auto value = lane < N ? m_value : ident;
740 return RAJA::hip::impl::warp_allreduce<combiner_t, element_type>(value);
// Returns a register holding RAJA::min<element_type>(m_value, a.m_value),
// computed independently by each thread on its own values.
750 self_type vmin(self_type a)
const
752 return self_type {RAJA::min<element_type>(m_value, a.m_value)};
// Builds the per-thread offset used by the segmented load/store pattern.
// The lane index is split into seg = lane >> segbits and i = low segbits bits
// of lane, and the result's raw value is set to
// seg * stride_outer + i * stride_inner.
// NOTE(review): the return statement is not shown in this listing.
765 static int_vector_type s_segmented_offsets(camp::idx_t segbits,
766 camp::idx_t stride_inner,
767 camp::idx_t stride_outer)
769 int_vector_type result;
771 auto lane = get_lane();
774 auto seg = lane >> segbits;
775 auto i = lane & ((1 << segbits) - 1);
777 result.get_raw_value() = seg * stride_outer + i * stride_inner;
// Segmented sum over the inner dimension.
// Visible steps:
//   1. x starts as this thread's m_value.
//   2. A loop with delta = 1, 2, 4, ... while delta < (1 << segbits) fetches
//      y = shfl_sync(x, get_lane() + delta).
//      NOTE(review): the statement combining y into x is not shown in this
//      listing; presumably x += y - confirm.
//   3. result takes shfl_sync(x, get_lane() << segbits).
//   4. Threads whose get_lane() >> (log2_warp_size - segbits) differs from
//      output_segment have result zeroed (the assignment at 846 follows the
//      !in_output_segment test).
// NOTE(review): the declaration and return of result are not shown.
818 self_type segmented_sum_inner(camp::idx_t segbits,
819 camp::idx_t output_segment)
const
823 element_type
x = m_value;
825 for (
int delta = 1; delta < 1 << segbits; delta = delta << 1)
829 element_type
y = hip::impl::shfl_sync(x, get_lane() + delta);
837 result.get_raw_value() = hip::impl::shfl_sync(x, get_lane() << segbits);
841 static constexpr
int log2_warp_size =
RAJA::log2(RAJA_HIP_WAVESIZE);
842 int our_output_segment = get_lane() >> (log2_warp_size - segbits);
843 bool in_output_segment = our_output_segment == output_segment;
844 if (!in_output_segment)
846 result.get_raw_value() = 0;
// Segmented sum over the outer dimension.
// Visible steps:
//   1. x starts as this thread's m_value.
//   2. A loop of (log2_warp_size - segbits) iterations uses
//      delta = s_num_elem >> (i + 1), i.e. half the register width and then
//      halving each iteration, and fetches y = shfl_sync(x, get_lane() + delta).
//      NOTE(review): the statement combining y into x is not shown in this
//      listing; presumably x += y - confirm.
//   3. result takes shfl_sync(x, get_from), where get_from is the low
//      segbits bits of the lane index.
//   4. mask is 1 when (get_lane() >> segbits) equals output_segment.
//      NOTE(review): the test guarding the zeroing at 916 is not shown;
//      presumably result is zeroed when mask is 0 - confirm.
// NOTE(review): the declaration and return of result are not shown.
886 self_type segmented_sum_outer(camp::idx_t segbits,
887 camp::idx_t output_segment)
const
891 element_type
x = m_value;
892 static constexpr
int log2_warp_size =
RAJA::log2(RAJA_HIP_WAVESIZE);
894 for (
int i = 0; i < log2_warp_size - segbits; ++i)
898 int delta = s_num_elem >> (i + 1);
899 element_type
y = hip::impl::shfl_sync(x, get_lane() + delta);
907 int get_from = get_lane() & ((1 << segbits) - 1);
908 result.get_raw_value() = hip::impl::shfl_sync(x, get_from);
910 int mask = (get_lane() >> segbits) == output_segment;
916 result.get_raw_value() = 0;
// Bounded segmented divide. seg = lane >> segbits and i = low segbits bits of
// lane; the bounds test is (seg >= num_outer || i >= num_inner), and the
// quotient is m_value / den.get_raw_value().
// NOTE(review): the out-of-bounds branch body, the declaration/return of
// result, and the declaration of segbits are not shown in this listing;
// presumably out-of-bounds lanes yield 0 and skip the division - confirm.
925 self_type segmented_divide_nm(self_type den,
927 camp::idx_t num_inner,
928 camp::idx_t num_outer)
const
932 auto lane = get_lane();
935 auto seg = lane >> segbits;
936 auto i = lane & ((1 << segbits) - 1);
938 if (seg >= num_outer || i >= num_inner)
944 result.get_raw_value() = m_value / den.get_raw_value();
// Segmented broadcast along the inner dimension. Each thread computes a
// source index i = (get_lane() & mask) + offset, where mask keeps the low
// segbits bits of the lane and offset = input_segment << segbits, then sets
// result to shfl_sync(m_value, i).
// NOTE(review): the declaration and return of result are not shown in this
// listing.
1000 self_type segmented_broadcast_inner(camp::idx_t segbits,
1001 camp::idx_t input_segment)
const
1005 camp::idx_t mask = (1 << segbits) - 1;
1006 camp::idx_t offset = input_segment << segbits;
1009 camp::idx_t i = (get_lane() & mask) + offset;
1011 result.get_raw_value() = hip::impl::shfl_sync(m_value, i);
// Segmented broadcast along the outer dimension. Each thread computes a
// source index i = (get_lane() >> segbits) + offset, where
// offset = input_segment * (s_num_elem >> segbits), then sets result to
// shfl_sync(m_value, i).
// NOTE(review): the declaration and return of result are not shown in this
// listing.
1056 self_type segmented_broadcast_outer(camp::idx_t segbits,
1057 camp::idx_t input_segment)
const
1061 camp::idx_t offset = input_segment * (self_type::s_num_elem >> segbits);
1063 camp::idx_t i = (get_lane() >> segbits) + offset;
1065 result.get_raw_value() = hip::impl::shfl_sync(m_value, i);
Header file for RAJA operator definitions.
RAJA header file defining SIMD/SIMT register operations.
Header file containing RAJA intrinsics templates for HIP execution.
Header file for common RAJA internal macro definitions.
#define RAJA_DEVICE
Definition: macros.hpp:66
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE T log2(T n) noexcept
evaluate log base 2 of n
Definition: math.hpp:40
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155
Definition: Operators.hpp:580
Definition: Operators.hpp:559
Definition: reduce.hpp:70