20 #ifndef RAJA_pattern_tensor_RegisterBase_HPP
21 #define RAJA_pattern_tensor_RegisterBase_HPP
23 #include "RAJA/config.hpp"
27 #include "camp/camp.hpp"
38 template<
typename T,
typename REGISTER_POLICY>
56 typename std::enable_if<std::is_arithmetic<LEFT>::value,
bool>::type =
true,
57 typename std::enable_if<std::is_base_of<RegisterConcreteBase, RIGHT>::value,
61 return RIGHT(lhs).add(rhs);
71 typename std::enable_if<std::is_arithmetic<LEFT>::value,
bool>::type =
true,
72 typename std::enable_if<std::is_base_of<RegisterConcreteBase, RIGHT>::value,
76 return RIGHT(lhs).subtract(rhs);
86 typename std::enable_if<std::is_arithmetic<LEFT>::value,
bool>::type =
true,
87 typename std::enable_if<std::is_base_of<RegisterConcreteBase, RIGHT>::value,
91 return rhs.scale(lhs);
101 typename std::enable_if<std::is_arithmetic<LEFT>::value,
bool>::type =
true,
102 typename std::enable_if<std::is_base_of<RegisterConcreteBase, RIGHT>::value,
106 return RIGHT(lhs).divide(rhs);
116 template<
typename Derived>
119 template<
typename T,
typename REGISTER_POLICY>
143 constexpr self_type
const* getThis()
const
145 return static_cast<self_type const*
>(
this);
152 static constexpr
bool is_root() {
return true; }
184 for (camp::idx_t i = 0; i < N; ++i)
201 x.broadcast(getThis()->
get(i));
214 template<
typename T2>
219 #ifdef RAJA_ENABLE_VECTOR_STATS
220 RAJA::tensor_stats::num_vector_load_strided_n++;
222 for (camp::idx_t i = 0; i < self_type::s_num_elem; ++i)
224 getThis()->set(ptr[offsets.get(i)], i);
238 template<
typename T2>
244 #ifdef RAJA_ENABLE_VECTOR_STATS
245 RAJA::tensor_stats::num_vector_load_strided_n++;
247 for (camp::idx_t i = 0; i < N; ++i)
249 getThis()->set(ptr[offsets.get(i)], i);
269 camp::idx_t stride_inner,
270 camp::idx_t stride_outer)
272 getThis()->gather(ptr, self_type::s_segmented_offsets(segbits, stride_inner,
289 camp::idx_t stride_inner,
290 camp::idx_t stride_outer,
291 camp::idx_t num_inner,
292 camp::idx_t num_outer)
295 camp::idx_t num_segments = self_type::s_num_elem >> segbits;
296 camp::idx_t seg_size = 1 << segbits;
298 camp::idx_t lane = 0;
299 for (camp::idx_t seg = 0; seg < num_segments; ++seg)
301 for (camp::idx_t i = 0; i < seg_size; ++i)
304 if (seg >= num_outer || i >= num_inner)
311 camp::idx_t offset = seg * stride_outer + i * stride_inner;
315 getThis()->set(value, lane);
334 template<
typename T2>
339 #ifdef RAJA_ENABLE_VECTOR_STATS
340 RAJA::tensor_stats::num_vector_load_strided_n++;
342 for (camp::idx_t i = 0; i < self_type::s_num_elem; ++i)
344 ptr[offsets.get(i)] = getThis()->get(i);
358 template<
typename T2>
364 #ifdef RAJA_ENABLE_VECTOR_STATS
365 RAJA::tensor_stats::num_vector_load_strided_n++;
367 for (camp::idx_t i = 0; i < N; ++i)
369 ptr[offsets.get(i)] = getThis()->get(i);
389 camp::idx_t stride_inner,
390 camp::idx_t stride_outer)
const
392 getThis()->scatter(ptr, self_type::s_segmented_offsets(
393 segbits, stride_inner, stride_outer));
409 camp::idx_t stride_inner,
410 camp::idx_t stride_outer,
411 camp::idx_t num_inner,
412 camp::idx_t num_outer)
const
415 camp::idx_t num_segments = self_type::s_num_elem >> segbits;
416 camp::idx_t seg_size = 1 << segbits;
418 camp::idx_t lane = 0;
419 for (camp::idx_t seg = 0; seg < num_segments; ++seg)
421 for (camp::idx_t i = 0; i < seg_size; ++i)
424 if (!(seg >= num_outer || i >= num_inner))
427 camp::idx_t offset = seg * stride_outer + i * stride_inner;
429 ptr[offset] = getThis()->get(lane);
449 getThis()->broadcast(value);
458 template<
typename T2>
462 getThis()->broadcast(value.get(0));
503 *getThis() = getThis()->add(
x);
534 *getThis() = getThis()->add(
x);
559 return getThis()->subtract(
x);
573 *getThis() = getThis()->subtract(
x);
588 return getThis()->subtract(
x);
602 *getThis() = getThis()->subtract(
x);
611 template<
typename RHS>
614 return getThis()->multiply(rhs);
622 template<
typename RHS>
625 *getThis() = getThis()->multiply(rhs);
651 *getThis() = getThis()->divide(
x);
666 return getThis()->divide(
x);
680 *getThis() = getThis()->divide(
x);
697 for (camp::idx_t i = 0; i < n; ++i)
699 q.set(getThis()->
get(i) / b.get(i), i);
717 for (camp::idx_t i = 0; i < n; ++i)
719 q.set(getThis()->
get(i) / b, i);
735 return getThis()->multiply(
x).sum();
771 return getThis()->multiply_add(b, -c);
783 return getThis()->multiply(
self_type(c));
824 auto const&
x = *getThis();
828 for (camp::idx_t i = 0; i < self_type::s_num_elem; ++i)
832 camp::idx_t xy_select = (i >> lvl) & 0x1;
835 z.set(xy_select == 0 ?
x.get(i) :
y.get(i - (1 << lvl)), i);
861 auto const&
x = *getThis();
865 camp::idx_t i0 = 1 << lvl;
867 for (camp::idx_t i = 0; i < self_type::s_num_elem; ++i)
871 camp::idx_t xy_select = (i >> lvl) & 0x1;
873 z.set(xy_select == 0 ?
x.get(i0 + i) :
y.get(i0 + i - (1 << lvl)), i);
888 camp::idx_t stride_inner,
889 camp::idx_t stride_outer)
893 camp::idx_t num_segments = self_type::s_num_elem >> segbits;
894 camp::idx_t seg_size = 1 << segbits;
896 camp::idx_t lane = 0;
897 for (camp::idx_t seg = 0; seg < num_segments; ++seg)
899 for (camp::idx_t i = 0; i < seg_size; ++i)
901 result.set(seg * stride_outer + i * stride_inner, lane);
944 camp::idx_t output_segment)
const
950 int output_offset = output_segment * self_type::s_num_elem >> segbits;
952 for (camp::idx_t i = 0; i < self_type::s_num_elem; ++i)
955 getThis()->get(i) + result.get((i >> segbits) + output_offset);
956 result.set(value, (i >> segbits) + output_offset);
1000 camp::idx_t output_segment)
const
1006 int output_offset = output_segment * (1 << segbits);
1008 for (camp::idx_t i = 0; i < self_type::s_num_elem; ++i)
1010 camp::idx_t output_i = output_offset + (i & ((1 << segbits) - 1));
1011 auto value = getThis()->get(i) + result.get(output_i);
1012 result.set(value, output_i);
1020 camp::idx_t segbits,
1021 camp::idx_t num_inner,
1022 camp::idx_t num_outer)
const
1026 camp::idx_t num_segments = self_type::s_num_elem >> segbits;
1027 camp::idx_t seg_size = 1 << segbits;
1029 camp::idx_t lane = 0;
1030 for (camp::idx_t seg = 0; seg < num_segments; ++seg)
1032 for (camp::idx_t i = 0; i < seg_size; ++i)
1035 if (seg >= num_outer || i >= num_inner)
1042 element_type div = getThis()->get(lane) / den.get(lane);
1044 result.set(div, lane);
1089 camp::idx_t output_segment,
1092 return getThis()->multiply(
x).segmented_sum_inner(segbits, output_segment);
1144 camp::idx_t input_segment)
const
1148 camp::idx_t mask = (1 << segbits) - 1;
1149 camp::idx_t offset = input_segment << segbits;
1153 for (camp::idx_t i = 0; i < self_type::s_num_elem; ++i)
1156 auto off = (i & mask) + offset;
1158 result.set(getThis()->get(off), i);
1202 camp::idx_t input_segment)
const
1206 camp::idx_t offset = input_segment * (self_type::s_num_elem >> segbits);
1210 for (camp::idx_t i = 0; i < self_type::s_num_elem; ++i)
1213 auto off = (i >> segbits) + offset;
1215 result.set(getThis()->get(off), i);
1229 std::string s =
"Register(" + std::to_string(self_type::s_num_elem) +
")[ ";
1232 for (camp::idx_t i = 0; i < self_type::s_num_elem; ++i)
1234 s += std::to_string(getThis()->
get(i)) +
" ";
RAJA header file defining a bit masking operator.
RAJA header file defining SIMD/SIMT register operations.
RAJA header file defining SIMD/SIMT register operations.
Header file containing RAJA simd policy definitions.
Definition: RegisterBase.hpp:39
RAJA_INLINE RAJA_HOST_DEVICE self_type transpose_shuffle_right(int lvl, self_type const &y) const
Definition: RegisterBase.hpp:859
RAJA_HOST_DEVICE RAJA_INLINE self_type const & segmented_store_nm(element_type *ptr, camp::idx_t segbits, camp::idx_t stride_inner, camp::idx_t stride_outer, camp::idx_t num_inner, camp::idx_t num_outer) const
Generic segmented store operation used for storing sub-matrices into larger arrays, where only the first num_inner elements of each of the first num_outer segments are written.
Definition: RegisterBase.hpp:407
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE element_type dot(self_type const &x) const
Dot product of two registers.
Definition: RegisterBase.hpp:733
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE static RAJA_INLINE self_type s_broadcast_n(element_type const &value, camp::idx_t N)
Broadcast scalar value to first N register elements.
Definition: RegisterBase.hpp:181
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE element_type min_n(camp::idx_t N) const
Definition: RegisterBase.hpp:793
RAJA_HOST_DEVICE RAJA_INLINE self_type & operator*=(RHS const &rhs)
Multiply a register with this register.
Definition: RegisterBase.hpp:623
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator-=(self_type const &x)
Subtract a register from this register.
Definition: RegisterBase.hpp:571
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type const & segmented_store(element_type *ptr, camp::idx_t segbits, camp::idx_t stride_inner, camp::idx_t stride_outer) const
Generic segmented store operation used for storing sub-matrices into larger arrays.
Definition: RegisterBase.hpp:387
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator=(RAJA::expt::Register< T2, RAJA::expt::scalar_register > const &value)
Set entire register to a single scalar value.
Definition: RegisterBase.hpp:459
RAJA_INLINE self_type segmented_broadcast_inner(camp::idx_t segbits, camp::idx_t input_segment) const
Definition: RegisterBase.hpp:1143
RAJA_INLINE self_type segmented_broadcast_outer(camp::idx_t segbits, camp::idx_t input_segment) const
Definition: RegisterBase.hpp:1201
RAJA_HOST_DEVICE RAJA_INLINE self_type const & scatter_n(element_type *ptr, RAJA::expt::Register< T2, REGISTER_POLICY > const &offsets, camp::idx_t N) const
Generic scatter operation for n-length subvector.
Definition: RegisterBase.hpp:359
RAJA_HOST_DEVICE RAJA_INLINE self_type & gather(element_type const *ptr, RAJA::expt::Register< T2, REGISTER_POLICY > offsets)
Generic gather operation for full vector.
Definition: RegisterBase.hpp:215
RAJA_HOST_DEVICE static constexpr RAJA_INLINE bool is_root()
Definition: RegisterBase.hpp:152
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator+=(element_type x)
Add a scalar to this register.
Definition: RegisterBase.hpp:532
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type multiply_subtract(self_type const &b, self_type const &c) const
Fused multiply subtract: fms(b, c) = (*this)*b-c.
Definition: RegisterBase.hpp:769
RAJA_INLINE constexpr RAJA_HOST_DEVICE RegisterBase(self_type const &)
Definition: RegisterBase.hpp:172
typename RegisterTraits< REGISTER_POLICY, T >::int_element_type int_element_type
Definition: RegisterBase.hpp:130
camp::decay< T > element_type
Definition: RegisterBase.hpp:125
RAJA_INLINE RAJA_HOST_DEVICE self_type transpose_shuffle_left(camp::idx_t lvl, self_type const &y) const
Definition: RegisterBase.hpp:822
RAJA_HOST_DEVICE RAJA_INLINE self_type operator*(RHS const &rhs) const
Multiply two registers, element-wise.
Definition: RegisterBase.hpp:612
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type operator+(element_type const &x) const
Add scalar to this register.
Definition: RegisterBase.hpp:518
RAJA_HOST_DEVICE RAJA_INLINE ~RegisterBase()
Definition: RegisterBase.hpp:162
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type divide_n(self_type const &b, camp::idx_t n) const
Divide n elements of this register by another register.
Definition: RegisterBase.hpp:694
RAJA_HOST_DEVICE RAJA_INLINE self_type & gather_n(element_type const *ptr, RAJA::expt::Register< T2, REGISTER_POLICY > const &offsets, camp::idx_t N)
Generic gather operation for n-length subvector.
Definition: RegisterBase.hpp:239
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator-=(element_type const &x)
Subtract a scalar from this register.
Definition: RegisterBase.hpp:600
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type segmented_dot(camp::idx_t segbits, camp::idx_t output_segment, self_type const &x) const
Definition: RegisterBase.hpp:1088
RAJA_HOST_DEVICE constexpr RAJA_INLINE RegisterBase()
Definition: RegisterBase.hpp:157
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE element_type max_n(camp::idx_t N) const
Definition: RegisterBase.hpp:802
RAJA_HOST_DEVICE constexpr RAJA_INLINE RegisterBase(RegisterBase const &)
Definition: RegisterBase.hpp:167
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type operator/(element_type const &x) const
Divide by a scalar, element wise.
Definition: RegisterBase.hpp:664
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator/=(self_type const &x)
Divide this register by another register.
Definition: RegisterBase.hpp:649
RAJA_INLINE self_type segmented_sum_inner(camp::idx_t segbits, camp::idx_t output_segment) const
Definition: RegisterBase.hpp:943
RAJA_INLINE self_type segmented_divide_nm(self_type den, camp::idx_t segbits, camp::idx_t num_inner, camp::idx_t num_outer) const
Definition: RegisterBase.hpp:1019
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type scale(element_type c) const
Definition: RegisterBase.hpp:781
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type operator/(self_type const &x) const
Divide two registers, element-wise.
Definition: RegisterBase.hpp:638
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type operator-() const
Negate the value of this register.
Definition: RegisterBase.hpp:546
RAJA_INLINE std::string to_string() const
Converts the vector to a string.
Definition: RegisterBase.hpp:1227
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type get_and_broadcast(int i) const
Extracts a scalar value and broadcasts to a new register.
Definition: RegisterBase.hpp:198
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator=(element_type value)
Set entire register to a single scalar value.
Definition: RegisterBase.hpp:447
static RAJA_INLINE int_vector_type s_segmented_offsets(camp::idx_t segbits, camp::idx_t stride_inner, camp::idx_t stride_outer)
Definition: RegisterBase.hpp:887
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & segmented_load(element_type const *ptr, camp::idx_t segbits, camp::idx_t stride_inner, camp::idx_t stride_outer)
Generic segmented load operation used for loading sub-matrices from larger arrays.
Definition: RegisterBase.hpp:267
RAJA_HOST_DEVICE RAJA_INLINE self_type const & scatter(element_type *ptr, RAJA::expt::Register< T2, REGISTER_POLICY > const &offsets) const
Generic scatter operation for full vector.
Definition: RegisterBase.hpp:335
camp::idx_t index_type
Definition: RegisterBase.hpp:127
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type multiply_add(self_type const &b, self_type const &c) const
Fused multiply add: fma(b, c) = (*this)*b+c.
Definition: RegisterBase.hpp:751
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type operator-(element_type const &x) const
Subtract scalar from this register.
Definition: RegisterBase.hpp:586
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator/=(element_type const &x)
Divide this register by a scalar.
Definition: RegisterBase.hpp:678
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type operator-(self_type const &x) const
Subtract two registers.
Definition: RegisterBase.hpp:557
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator=(self_type const &x)
Assign one register to another.
Definition: RegisterBase.hpp:475
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator+=(self_type const &x)
Add a register to this register.
Definition: RegisterBase.hpp:501
RAJA_INLINE self_type segmented_sum_outer(camp::idx_t segbits, camp::idx_t output_segment) const
Definition: RegisterBase.hpp:999
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type divide_n(element_type const &b, camp::idx_t n) const
Divide n elements of this register by a scalar.
Definition: RegisterBase.hpp:714
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type operator+(self_type const &x) const
Add two registers.
Definition: RegisterBase.hpp:490
RAJA_HOST_DEVICE RAJA_INLINE self_type & segmented_load_nm(element_type const *ptr, camp::idx_t segbits, camp::idx_t stride_inner, camp::idx_t stride_outer, camp::idx_t num_inner, camp::idx_t num_outer)
Generic segmented load operation used for loading sub-matrices from larger arrays, where only a partial sub-matrix (the first num_inner elements of each of the first num_outer segments) is loaded.
Definition: RegisterBase.hpp:287
Definition: RegisterBase.hpp:117
Definition: RegisterBase.hpp:47
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
#define RAJA_SUPPRESS_HD_WARN
Definition: macros.hpp:68
RAJA_INLINE RAJA_HOST_DEVICE RIGHT operator/(LEFT const &lhs, RIGHT const &rhs)
Definition: RegisterBase.hpp:104
RAJA_INLINE RAJA_HOST_DEVICE RIGHT operator+(LEFT const &lhs, RIGHT const &rhs)
Definition: RegisterBase.hpp:59
RAJA_INLINE RAJA_HOST_DEVICE RIGHT operator-(LEFT const &lhs, RIGHT const &rhs)
Definition: RegisterBase.hpp:74
RAJA_INLINE RAJA_HOST_DEVICE RIGHT operator*(LEFT const &lhs, RIGHT const &rhs)
Definition: RegisterBase.hpp:89
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56