21 #ifndef RAJA_policy_tensor_arch_cuda_cuda_warp_register_HPP
22 #define RAJA_policy_tensor_arch_cuda_cuda_warp_register_HPP
24 #include "RAJA/config.hpp"
26 #if defined(RAJA_CUDA_ACTIVE)
40 template<
typename ELEMENT_TYPE>
41 class Register<ELEMENT_TYPE, cuda_warp_register>
42 :
public internal::expt::RegisterBase<
43 Register<ELEMENT_TYPE, cuda_warp_register>>
47 internal::expt::RegisterBase<Register<ELEMENT_TYPE, cuda_warp_register>>;
49 using register_policy = cuda_warp_register;
50 using self_type = Register<ELEMENT_TYPE, cuda_warp_register>;
51 using element_type = ELEMENT_TYPE;
52 using register_type = ELEMENT_TYPE;
54 using int_vector_type = Register<int64_t, cuda_warp_register>;
61 static constexpr
int s_num_elem = RAJA_CUDA_WARPSIZE;
// Default constructor: default-constructs the base_type subobject and
// initializes the stored value m_value to 0.
69 constexpr Register() : base_type(), m_value(0) {}
// Constructs from a single element value.
//   c: value copied into m_value.
// The constructor is not marked explicit, so an element_type converts
// implicitly to a Register.
77 constexpr Register(element_type c) : base_type(), m_value(c) {}
// Copy constructor: default-constructs the base_type subobject and copies
// the stored value from c.m_value.
85 constexpr Register(self_type
const& c) : base_type(), m_value(c.m_value) {}
93 self_type& operator=(self_type
const& c)
102 self_type& operator=(element_type c)
// Returns the lane index of the calling thread, taken directly from
// threadIdx.x. The other members of this class use this value to index
// memory and to select shuffle source lanes.
// NOTE(review): threadIdx.x is returned unmodified (no reduction modulo the
// warp size), so this presumably relies on the x-extent of the thread block
// being a single warp -- confirm against the launch policy.
// NOTE(review): declared constexpr although threadIdx.x is only known at run
// time -- confirm this is intentional.
114 constexpr
static int get_lane() {
return threadIdx.x; }
// Read-only access to the stored value.
// Returns a const reference to m_value (no copy is made).
119 constexpr element_type
const& get_raw_value()
const {
return m_value; }
// Mutable access to the stored value.
// Returns a non-const reference to m_value, so callers can assign through it
// (other members of this class do so via `result.get_raw_value() = ...`).
124 element_type& get_raw_value() {
return m_value; }
// Returns true only for the thread whose get_lane() is 0, false otherwise.
129 static constexpr
bool is_root() {
return get_lane() == 0; }
138 self_type& load_packed(element_type
const* ptr)
141 auto lane = get_lane();
156 self_type& load_packed_n(element_type
const* ptr,
int N)
158 auto lane = get_lane();
165 m_value = element_type(0);
177 self_type& load_strided(element_type
const* ptr,
int stride)
180 auto lane = get_lane();
182 m_value = ptr[stride * lane];
195 self_type& load_strided_n(element_type
const* ptr,
int stride,
int N)
197 auto lane = get_lane();
201 m_value = ptr[stride * lane];
205 m_value = element_type(0);
222 self_type& gather(element_type
const* ptr, int_vector_type offsets)
225 m_value = ptr[offsets.get_raw_value()];
242 self_type& gather_n(element_type
const* ptr,
243 int_vector_type offsets,
248 m_value = ptr[offsets.get_raw_value()];
252 m_value = element_type(0);
270 self_type& segmented_load(element_type
const* ptr,
272 camp::idx_t stride_inner,
273 camp::idx_t stride_outer)
275 auto lane = get_lane();
278 auto seg = lane >> segbits;
279 auto i = lane & ((1 << segbits) - 1);
281 m_value = ptr[seg * stride_outer + i * stride_inner];
296 self_type& segmented_load_nm(element_type
const* ptr,
298 camp::idx_t stride_inner,
299 camp::idx_t stride_outer,
300 camp::idx_t num_inner,
301 camp::idx_t num_outer)
303 auto lane = get_lane();
306 auto seg = lane >> segbits;
307 auto i = lane & ((1 << segbits) - 1);
309 if (seg >= num_outer || i >= num_inner)
311 m_value = element_type(0);
315 m_value = ptr[seg * stride_outer + i * stride_inner];
328 self_type
const& store_packed(element_type* ptr)
const
331 auto lane = get_lane();
345 self_type
const& store_packed_n(element_type* ptr,
int N)
const
348 auto lane = get_lane();
364 self_type
const& store_strided(element_type* ptr,
int stride)
const
367 auto lane = get_lane();
369 ptr[lane * stride] = m_value;
381 self_type
const& store_strided_n(element_type* ptr,
int stride,
int N)
const
384 auto lane = get_lane();
388 ptr[lane * stride] = m_value;
402 template<
typename T2>
403 RAJA_DEVICE RAJA_INLINE self_type
const& scatter(element_type* ptr,
404 T2
const& offsets)
const
407 ptr[offsets.get_raw_value()] = m_value;
421 template<
typename T2>
422 RAJA_DEVICE RAJA_INLINE self_type
const& scatter_n(element_type* ptr,
428 ptr[offsets.get_raw_value()] = m_value;
442 self_type
const& segmented_store(element_type* ptr,
444 camp::idx_t stride_inner,
445 camp::idx_t stride_outer)
const
447 auto lane = get_lane();
450 auto seg = lane >> segbits;
451 auto i = lane & ((1 << segbits) - 1);
453 ptr[seg * stride_outer + i * stride_inner] = m_value;
466 self_type
const& segmented_store_nm(element_type* ptr,
468 camp::idx_t stride_inner,
469 camp::idx_t stride_outer,
470 camp::idx_t num_inner,
471 camp::idx_t num_outer)
const
473 auto lane = get_lane();
476 auto seg = lane >> segbits;
477 auto i = lane & ((1 << segbits) - 1);
479 if (seg >= num_outer || i >= num_inner)
485 ptr[seg * stride_outer + i * stride_inner] = m_value;
498 return __shfl_sync(0xffffffff, m_value, i);
509 self_type& set(element_type value,
int i)
511 auto lane = get_lane();
522 self_type& broadcast(element_type
const& a)
534 self_type get_and_broadcast(
int i)
const
537 x.m_value = __shfl_sync(0xffffffff, m_value, i);
544 self_type& copy(self_type
const& src)
546 m_value = src.m_value;
// Addition: returns a new register constructed from m_value + b.m_value.
// Const member; *this and b are not modified.
553 self_type add(self_type
const& b)
const
555 return self_type(m_value + b.m_value);
// Subtraction: returns a new register constructed from m_value - b.m_value.
// Const member; *this and b are not modified.
561 self_type subtract(self_type
const& b)
const
563 return self_type(m_value - b.m_value);
// Multiplication: returns a new register constructed from
// m_value * b.m_value. Const member; *this and b are not modified.
569 self_type multiply(self_type
const& b)
const
571 return self_type(m_value * b.m_value);
// Division: returns a new register constructed from m_value / b.m_value.
// Const member; *this and b are not modified. No check for a zero divisor is
// made on the visible lines.
577 self_type divide(self_type
const& b)
const
579 return self_type(m_value / b.m_value);
// Division limited to the first N lanes.
//   b: divisor register.
//   N: threads with get_lane() < N return m_value / b.m_value;
//      all other threads return a register holding element_type(0).
// Because the selection is a conditional expression, the division is only
// evaluated on threads that satisfy get_lane() < N.
585 self_type divide_n(self_type
const& b,
int N)
const
587 return get_lane() < N ? self_type(m_value / b.m_value)
588 : self_type(element_type(0));
594 template<
typename RETURN_TYPE = self_type>
596 typename std::enable_if<!std::numeric_limits<element_type>::is_integer,
598 multiply_add(self_type
const& b, self_type
const& c)
const
600 return self_type(fma(m_value, b.m_value, c.m_value));
606 template<
typename RETURN_TYPE = self_type>
608 typename std::enable_if<std::numeric_limits<element_type>::is_integer,
610 multiply_add(self_type
const& b, self_type
const& c)
const
612 return self_type(m_value * b.m_value + c.m_value);
618 template<
typename RETURN_TYPE = self_type>
620 typename std::enable_if<!std::numeric_limits<element_type>::is_integer,
622 multiply_subtract(self_type
const& b, self_type
const& c)
const
624 return self_type(fma(m_value, b.m_value, -c.m_value));
630 template<
typename RETURN_TYPE = self_type>
632 typename std::enable_if<std::numeric_limits<element_type>::is_integer,
634 multiply_subtract(self_type
const& b, self_type
const& c)
const
636 return self_type(m_value * b.m_value - c.m_value);
646 element_type
sum()
const
652 return RAJA::cuda::impl::warp_allreduce<combiner_t, element_type>(m_value);
662 element_type
max()
const
669 return RAJA::cuda::impl::warp_allreduce<combiner_t, element_type>(m_value);
679 element_type max_n(
int N)
const
687 auto lane = get_lane();
688 auto value = lane < N ? m_value : ident;
689 return RAJA::cuda::impl::warp_allreduce<combiner_t, element_type>(value);
// Maximum of two registers: returns a new register holding
// RAJA::max<element_type>(m_value, a.m_value). The argument is taken by
// value; *this is not modified (const member).
699 self_type vmax(self_type a)
const
701 return self_type {RAJA::max<element_type>(m_value, a.m_value)};
711 element_type
min()
const
718 return RAJA::cuda::impl::warp_allreduce<combiner_t, element_type>(m_value);
728 element_type min_n(
int N)
const
736 auto lane = get_lane();
737 auto value = lane < N ? m_value : ident;
738 return RAJA::cuda::impl::warp_allreduce<combiner_t, element_type>(value);
// Minimum of two registers: returns a new register holding
// RAJA::min<element_type>(m_value, a.m_value). The argument is taken by
// value; *this is not modified (const member).
748 self_type vmin(self_type a)
const
750 return self_type {RAJA::min<element_type>(m_value, a.m_value)};
763 static int_vector_type s_segmented_offsets(camp::idx_t segbits,
764 camp::idx_t stride_inner,
765 camp::idx_t stride_outer)
767 int_vector_type result;
769 auto lane = get_lane();
772 auto seg = lane >> segbits;
773 auto i = lane & ((1 << segbits) - 1);
775 result.get_raw_value() = seg * stride_outer + i * stride_inner;
816 self_type segmented_sum_inner(camp::idx_t segbits,
817 camp::idx_t output_segment)
const
821 element_type
x = m_value;
823 for (
int delta = 1; delta < 1 << segbits; delta = delta << 1)
827 element_type
y = __shfl_sync(0xffffffff, x, get_lane() + delta);
835 result.get_raw_value() = __shfl_sync(0xffffffff, x, get_lane() << segbits);
839 static constexpr
int log2_warp_size =
RAJA::log2(RAJA_CUDA_WARPSIZE);
840 int our_output_segment = get_lane() >> (log2_warp_size - segbits);
841 bool in_output_segment = our_output_segment == output_segment;
842 if (!in_output_segment)
844 result.get_raw_value() = 0;
884 self_type segmented_sum_outer(camp::idx_t segbits,
885 camp::idx_t output_segment)
const
889 element_type
x = m_value;
890 static constexpr
int log2_warp_size =
RAJA::log2(RAJA_CUDA_WARPSIZE);
892 for (
int i = 0; i < log2_warp_size - segbits; ++i)
896 int delta = s_num_elem >> (i + 1);
897 element_type
y = __shfl_sync(0xffffffff, x, get_lane() + delta);
905 int get_from = get_lane() & ((1 << segbits) - 1);
906 result.get_raw_value() = __shfl_sync(0xffffffff, x, get_from);
908 int mask = (get_lane() >> segbits) == output_segment;
914 result.get_raw_value() = 0;
923 self_type segmented_divide_nm(self_type den,
925 camp::idx_t num_inner,
926 camp::idx_t num_outer)
const
930 auto lane = get_lane();
933 auto seg = lane >> segbits;
934 auto i = lane & ((1 << segbits) - 1);
936 if (seg >= num_outer || i >= num_inner)
942 result.get_raw_value() = m_value / den.get_raw_value();
998 self_type segmented_broadcast_inner(camp::idx_t segbits,
999 camp::idx_t input_segment)
const
1003 camp::idx_t mask = (1 << segbits) - 1;
1004 camp::idx_t offset = input_segment << segbits;
1007 camp::idx_t i = (get_lane() & mask) + offset;
1009 result.get_raw_value() = __shfl_sync(0xffffffff, m_value, i);
1054 self_type segmented_broadcast_outer(camp::idx_t segbits,
1055 camp::idx_t input_segment)
const
1059 camp::idx_t offset = input_segment * (self_type::s_num_elem >> segbits);
1061 camp::idx_t i = (get_lane() >> segbits) + offset;
1063 result.get_raw_value() = __shfl_sync(0xffffffff, m_value, i);
Header file for RAJA operator definitions.
RAJA header file defining SIMD/SIMT register operations.
Header file containing RAJA intrinsics templates for CUDA execution.
Header file for common RAJA internal macro definitions.
#define RAJA_DEVICE
Definition: macros.hpp:66
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result sum(Args... args)
Definition: foldl.hpp:143
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
RAJA_HOST_DEVICE constexpr RAJA_INLINE T log2(T n) noexcept
evaluate log base 2 of n
Definition: math.hpp:40
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155
Definition: Operators.hpp:580
Definition: Operators.hpp:559
Definition: reduce.hpp:70