20 #ifndef RAJA_pattern_tensor_TensorRegisterBase_HPP
21 #define RAJA_pattern_tensor_TensorRegisterBase_HPP
23 #include "RAJA/config.hpp"
27 #include "camp/camp.hpp"
41 class TensorExpressionConcreteBase;
44 template<
typename TENSOR, camp::
idx_t DIM>
47 static constexpr camp::idx_t
value = TENSOR::s_dim_size(DIM);
57 template<
typename LHS,
typename RHS>
69 return lhs.multiply(rhs);
73 template<
typename REF_TYPE>
80 template<
typename RHS>
89 template<camp::
idx_t N, camp::
idx_t D>
92 static constexpr camp::idx_t
value = (N % D) > 0 ? (1 + N / D) : (N / D);
104 template<
typename Derived>
107 template<
typename REGISTER_POLICY,
110 typename camp::idx_t... SIZES>
113 TensorRegister<REGISTER_POLICY, T, LAYOUT, camp::idx_seq<SIZES...>>>
118 TensorRegister<REGISTER_POLICY, T, LAYOUT, camp::idx_seq<SIZES...>>;
121 static constexpr camp::idx_t s_num_dims =
sizeof...(SIZES);
123 static constexpr camp::idx_t s_num_registers =
142 constexpr self_type
const* getThis()
const
144 return static_cast<self_type const*
>(
this);
174 template<
typename RHS,
175 typename std::enable_if<
176 std::is_base_of<ET::TensorExpressionConcreteBase, RHS>::value,
181 *
this = rhs.eval(self_type::s_get_default_tile());
184 template<
typename... REGS>
187 : m_registers {reg0, regs...}
189 static_assert(1 +
sizeof...(REGS) == s_num_registers,
190 "Incompatible number of registers");
196 static constexpr
bool is_root() {
return register_type::is_root(); }
198 template<
typename REF_TYPE>
206 template<
typename REF_TYPE>
225 return (dim == 0) ? self_type::s_num_elem : 0;
238 camp::int_seq<int, int(SIZES * 0)...>,
239 camp::int_seq<int, int(SIZES)...>>
243 camp::int_seq<int, int(SIZES * 0)...>,
244 camp::int_seq<int, int(SIZES)...>>();
255 constexpr
bool sink()
const {
return false; }
265 for (camp::idx_t i = 0; i < s_num_registers; ++i)
267 m_registers[i] = c.vec(i);
280 for (camp::idx_t i = 0; i < s_num_registers; ++i)
297 for (camp::idx_t i = 0; i < s_num_registers; ++i)
299 m_registers[i].broadcast(v);
313 for (camp::idx_t i = 0; i < N; ++i)
315 getThis()->set(value, i);
330 x.broadcast(getThis()->
get(i));
340 for (camp::idx_t i = 0; i < s_num_registers; ++i)
342 result.vec(i) = m_registers[i].add(mat.vec(i));
353 for (camp::idx_t i = 0; i < s_num_registers; ++i)
355 result.vec(i) = m_registers[i].subtract(mat.vec(i));
369 for (camp::idx_t i = 0; i < s_num_registers; ++i)
371 result.vec(i) = m_registers[i].multiply(
x.vec(i));
385 for (camp::idx_t i = 0; i < s_num_registers; ++i)
387 result.vec(i) = m_registers[i].multiply_add(
x.vec(i), add.vec(i));
398 for (camp::idx_t reg = 0; reg < s_num_registers; ++reg)
400 result.vec(reg) = m_registers[reg].divide(mat.vec(reg));
417 for (camp::idx_t reg = 0; reg < s_num_registers; ++reg)
419 result += m_registers[reg].multiply(
x.vec(reg)).sum();
435 getThis()->broadcast(value);
444 template<
typename T2>
449 camp::idx_seq<>>
const& value)
451 getThis()->broadcast(value.get(0));
492 *getThis() = getThis()->add(
x);
518 *getThis() = getThis()->add(
x);
543 return getThis()->subtract(
x);
557 *getThis() = getThis()->subtract(
x);
572 return getThis()->subtract(
x);
586 *getThis() = getThis()->subtract(
x);
595 template<
typename RHS>
608 template<
typename RHS>
638 *getThis() = getThis()->divide(
x);
653 return getThis()->divide(
x);
667 *getThis() = getThis()->divide(
x);
680 for (camp::idx_t i = 0; i < s_num_registers; ++i)
682 result.vec(i) = m_registers[i].vmin(
x.vec(i));
696 for (camp::idx_t i = 0; i < s_num_registers; ++i)
698 result.vec(i) = m_registers[i].vmax(
x.vec(i));
723 return m_registers[reg];
741 return getThis()->multiply_add(b, -c);
753 return getThis()->multiply(
self_type(c));
765 *getThis() = getThis()->add(
x);
778 *getThis() = getThis()->subtract(
x);
791 *getThis() = getThis()->multiply(
x);
804 *getThis() = getThis()->multiply_add(
x,
y);
817 *getThis() = getThis()->multiply_subtract(
x,
y);
830 *getThis() = getThis()->divide(
x);
843 *getThis() = getThis()->scale(
x);
RAJA header file defining SIMD/SIMT register operations.
RAJA header file defining SIMD/SIMT register operations.
Definition: RegisterBase.hpp:39
Definition: TensorRegister.hpp:46
RAJA_HOST_DEVICE RAJA_INLINE self_type & operator*=(RHS const &rhs)
Multiply a vector with this vector.
Definition: TensorRegisterBase.hpp:609
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type multiply_subtract(self_type const &b, self_type const &c) const
Fused multiply subtract: fms(b, c) = (*this)*b-c.
Definition: TensorRegisterBase.hpp:739
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type & inplace_divide(self_type x)
Definition: TensorRegisterBase.hpp:828
RAJA_INLINE RAJA_HOST_DEVICE TensorRegisterBase(self_type const &c)
Definition: TensorRegisterBase.hpp:164
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator/=(element_type const &x)
Divide this vector by a scalar.
Definition: TensorRegisterBase.hpp:665
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type & inplace_multiply_subtract(self_type x, self_type y)
Definition: TensorRegisterBase.hpp:815
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type & inplace_scale(element_type x)
Definition: TensorRegisterBase.hpp:841
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & broadcast_n(element_type const &value, camp::idx_t N)
Broadcast scalar value to first N register elements.
Definition: TensorRegisterBase.hpp:311
camp::decay< T > element_type
Definition: TensorRegisterBase.hpp:119
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type & inplace_multiply(self_type x)
Definition: TensorRegisterBase.hpp:789
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type scale(element_type c) const
Definition: TensorRegisterBase.hpp:751
RAJA_HOST_DEVICE RAJA_INLINE ~TensorRegisterBase()
Definition: TensorRegisterBase.hpp:169
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type const & operator=(RAJA::expt::TensorRegister< RAJA::expt::scalar_register, T2, RAJA::expt::ScalarLayout, camp::idx_seq<>> const &value)
Set entire vector to a single scalar value.
Definition: TensorRegisterBase.hpp:445
RAJA_HOST_DEVICE RAJA_INLINE self_type & copy(self_type const &c)
Definition: TensorRegisterBase.hpp:263
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type & inplace_add(self_type x)
Definition: TensorRegisterBase.hpp:763
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator/=(self_type const &x)
Divide this vector by another vector.
Definition: TensorRegisterBase.hpp:636
RAJA_HOST_DEVICE RAJA_INLINE self_type vmin(self_type x) const
Returns element wise minimum value tensor.
Definition: TensorRegisterBase.hpp:677
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type operator/(element_type const &x) const
Divide by a scalar, element wise.
Definition: TensorRegisterBase.hpp:651
RAJA_HOST_DEVICE RAJA_INLINE self_type divide(self_type const &mat) const
Definition: TensorRegisterBase.hpp:395
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type operator+(element_type const &x) const
Add vector to a scalar.
Definition: TensorRegisterBase.hpp:505
RAJA_INLINE RAJA_HOST_DEVICE TensorRegisterBase(RHS const &rhs)
Definition: TensorRegisterBase.hpp:178
RAJA_HOST_DEVICE RAJA_INLINE self_type vmax(self_type x) const
Returns element wise maximum value tensor.
Definition: TensorRegisterBase.hpp:693
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type operator/(self_type const &x) const
Divide two vector registers, element wise.
Definition: TensorRegisterBase.hpp:625
RAJA_HOST_DEVICE constexpr RAJA_INLINE TensorRegisterBase()
Definition: TensorRegisterBase.hpp:154
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator-=(self_type const &x)
Subtract a vector from this vector.
Definition: TensorRegisterBase.hpp:555
RAJA_HOST_DEVICE RAJA_INLINE TensorDefaultOperation< self_type, RHS >::multiply_type operator*(RHS const &rhs) const
Multiply two vector registers, element wise.
Definition: TensorRegisterBase.hpp:598
REGISTER_POLICY register_policy
Definition: TensorRegisterBase.hpp:131
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator+=(self_type const &x)
Add a vector to this vector.
Definition: TensorRegisterBase.hpp:490
RAJA_HOST_DEVICE RAJA_INLINE self_type add(self_type const &mat) const
Definition: TensorRegisterBase.hpp:337
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type const & operator=(element_type value)
Set entire vector to a single scalar value.
Definition: TensorRegisterBase.hpp:433
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type const & operator=(self_type const &x)
Assign one register to another.
Definition: TensorRegisterBase.hpp:464
RAJA_HOST_DEVICE constexpr RAJA_INLINE register_type const & vec(int i) const
Definition: TensorRegisterBase.hpp:711
RAJA_HOST_DEVICE RAJA_INLINE self_type & broadcast(element_type v)
Definition: TensorRegisterBase.hpp:295
RAJA_HOST_DEVICE RAJA_INLINE register_type & vec(int i)
Definition: TensorRegisterBase.hpp:706
RAJA_HOST_DEVICE RAJA_INLINE self_type & clear()
Definition: TensorRegisterBase.hpp:278
RAJA_HOST_DEVICE static constexpr RAJA_INLINE StaticTensorTile< int, TENSOR_FULL, camp::int_seq< int, int(SIZES *0)... >, camp::int_seq< int, int(SIZES)... > > s_get_default_tile()
Definition: TensorRegisterBase.hpp:240
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type get_and_broadcast(int i) const
Extracts a scalar value and broadcasts to a new register.
Definition: TensorRegisterBase.hpp:327
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type & inplace_multiply_add(self_type x, self_type y)
Definition: TensorRegisterBase.hpp:802
RAJA_HOST_DEVICE RAJA_INLINE register_type & get_register(int reg)
Definition: TensorRegisterBase.hpp:716
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type operator-() const
Negate the value of this vector.
Definition: TensorRegisterBase.hpp:530
RAJA_HOST_DEVICE constexpr RAJA_INLINE bool sink() const
convenience routine to allow Vector classes to use camp::sink() across a variety of register types,...
Definition: TensorRegisterBase.hpp:255
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type operator-(element_type const &x) const
Subtract scalar from this register.
Definition: TensorRegisterBase.hpp:570
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type operator+(self_type const &x) const
Add two vector registers.
Definition: TensorRegisterBase.hpp:479
RAJA_HOST_DEVICE RAJA_INLINE TensorRegisterBase(element_type c)
Definition: TensorRegisterBase.hpp:159
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE self_type & inplace_subtract(self_type x)
Definition: TensorRegisterBase.hpp:776
RAJA_HOST_DEVICE static constexpr RAJA_INLINE int s_dim_elem(int dim)
Definition: TensorRegisterBase.hpp:223
RAJA_HOST_DEVICE RAJA_INLINE self_type subtract(self_type const &mat) const
Definition: TensorRegisterBase.hpp:350
RAJA_HOST_DEVICE constexpr RAJA_INLINE register_type const & get_register(int reg) const
Definition: TensorRegisterBase.hpp:721
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator-=(element_type const &x)
Subtract a scalar from this vector.
Definition: TensorRegisterBase.hpp:584
RAJA_INLINE RAJA_HOST_DEVICE element_type dot(self_type const &x) const
Dot product of two vectors.
Definition: TensorRegisterBase.hpp:413
RAJA_HOST_DEVICE RAJA_INLINE TensorRegisterBase(register_type reg0, REGS const &... regs)
Definition: TensorRegisterBase.hpp:185
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type operator-(self_type const &x) const
Subtract two vector registers.
Definition: TensorRegisterBase.hpp:541
RAJA_HOST_DEVICE static constexpr RAJA_INLINE bool is_root()
Definition: TensorRegisterBase.hpp:196
RAJA_HOST_DEVICE RAJA_INLINE self_type multiply(self_type const &x) const
Definition: TensorRegisterBase.hpp:366
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type & operator+=(element_type x)
Add a scalar to this vector.
Definition: TensorRegisterBase.hpp:516
RAJA_HOST_DEVICE static constexpr RAJA_INLINE TensorRegisterStoreRef< REF_TYPE > create_et_store_ref(REF_TYPE const &ref)
Definition: TensorRegisterBase.hpp:200
camp::idx_t index_type
Definition: TensorRegisterBase.hpp:127
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE static RAJA_INLINE self_type s_load_ref(REF_TYPE const &ref)
Definition: TensorRegisterBase.hpp:207
RAJA_HOST_DEVICE RAJA_INLINE self_type multiply_add(self_type const &x, self_type const &add) const
Definition: TensorRegisterBase.hpp:382
Definition: TensorRegisterBase.hpp:105
Definition: TensorRegisterBase.hpp:96
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
#define RAJA_SUPPRESS_HD_WARN
Definition: macros.hpp:68
@ TENSOR_FULL
Definition: TensorRef.hpp:236
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
Definition: TensorLayout.hpp:35
Definition: TensorRegisterBase.hpp:91
static constexpr camp::idx_t value
Definition: TensorRegisterBase.hpp:92
Definition: TensorRef.hpp:309
Definition: TensorRegisterBase.hpp:59
RAJA_HOST_DEVICE static RAJA_INLINE multiply_type multiply(LHS const &lhs, RHS const &rhs)
Definition: TensorRegisterBase.hpp:67
decltype(LHS().multiply(RHS())) multiply_type
Definition: TensorRegisterBase.hpp:61
Definition: TensorRegisterBase.hpp:46
static constexpr camp::idx_t value
Definition: TensorRegisterBase.hpp:47
Definition: TensorRegisterBase.hpp:75
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type operator=(RHS const &rhs)
Definition: TensorRegisterBase.hpp:81
REF_TYPE m_ref
Definition: TensorRegisterBase.hpp:77