RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
TensorMultiply.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_pattern_tensor_ET_TensorMultiply_HPP
21 #define RAJA_pattern_tensor_ET_TensorMultiply_HPP
22 
23 #include "RAJA/config.hpp"
24 
25 #include "RAJA/util/macros.hpp"
26 
29 
30 namespace RAJA
31 {
32 namespace internal
33 {
34 namespace expt
35 {
36 
37 namespace ET
38 {
39 
40 // forward decl for FMA contraction
// TensorMultiply::operator+ (further down) returns a
// TensorMultiplyAdd<left_operand_type, right_operand_type,
// normalize_operand_t<ADD>>, so this template name must be visible before
// TensorMultiply is defined. The full definition is in TensorMultiplyAdd.hpp
// (per the cross-reference entries of this page).
41 template<typename LEFT_OPERAND_TYPE,
42  typename RIGHT_OPERAND_TYPE,
43  typename ADD_TYPE>
44 class TensorMultiplyAdd;
45 
// TensorMultiply: expression-template node for the product of two operands.
//
// The node stores copies of both operands (m_left_operand / m_right_operand)
// and forwards its queries to `multiply_op`: dimension sizes go through
// multiply_op::getDimSize, tile evaluation through multiply_op::multiply,
// and s_num_dims is taken from multiply_op::s_num_dims.
//
// NOTE(review): this listing is missing several original lines (the line
// numbers jump: 47, 52, 55, 60, 70, 79, 92-94, 97, 103-105, 108, 114-117,
// 119-121, 124-125, 131, 135). The `multiply_op` alias is defined in one of
// them; the cross-reference entries suggest it is
// MultiplyOperator<LEFT_OPERAND_TYPE, RIGHT_OPERAND_TYPE> and that line 60 is
// `using result_type = typename multiply_op::result_type;` -- confirm against
// the real header.
46 template<typename LEFT_OPERAND_TYPE, typename RIGHT_OPERAND_TYPE>
48  : public TensorExpressionBase<
49  TensorMultiply<LEFT_OPERAND_TYPE, RIGHT_OPERAND_TYPE>>
50 {
51 public:
    // Operand types; operator+ below uses them to name the
    // TensorMultiplyAdd it returns.
53  using left_operand_type = LEFT_OPERAND_TYPE;
54  using right_operand_type = RIGHT_OPERAND_TYPE;
56 
    // Element and index types are taken from the LEFT operand only.
    // NOTE(review): nothing visible here checks that RIGHT_OPERAND_TYPE
    // uses the same element/index types.
57  using element_type = typename LEFT_OPERAND_TYPE::element_type;
58  using index_type = typename LEFT_OPERAND_TYPE::index_type;
59 
    // Number of dimensions of the product, as reported by multiply_op.
61  static constexpr camp::idx_t s_num_dims = multiply_op::s_num_dims;
62 
63 private:
    // Operands are held by value; the constructor copies them in.
64  left_operand_type m_left_operand;
65  right_operand_type m_right_operand;
66 
67 public:
    // Builds the node by copying both operands.
68  RAJA_INLINE
69 
71  TensorMultiply(left_operand_type const& left_operand,
72  right_operand_type const& right_operand)
73  : m_left_operand {left_operand},
74  m_right_operand {right_operand}
75  {}
76 
    // Size of dimension `dim` of the product; delegates to
    // multiply_op::getDimSize together with both stored operands.
77  RAJA_INLINE
78 
80  constexpr int getDimSize(int dim) const
81  {
82  return multiply_op::getDimSize(dim, m_left_operand, m_right_operand);
83  }
84 
    // Evaluates the product on `tile`. The return type is whatever
    // multiply_op::multiply yields for these operands (per the
    // cross-reference, presumably left.eval(tile) * right.eval(tile) --
    // confirm).
85  template<typename TILE_TYPE>
86  RAJA_INLINE RAJA_HOST_DEVICE auto eval(TILE_TYPE const& tile) const
87  -> decltype(multiply_op::multiply(tile, m_left_operand, m_right_operand))
88  {
89  return multiply_op::multiply(tile, m_left_operand, m_right_operand);
90  }
91 
    // Read-only access to the stored left operand.
95  RAJA_INLINE
96 
98  constexpr left_operand_type const& getLeftOperand() const
99  {
100  return m_left_operand;
101  }
102 
    // Read-only access to the stored right operand.
106  RAJA_INLINE
107 
109  constexpr right_operand_type const& getRightOperand() const
110  {
111  return m_right_operand;
112  }
113 
    // FMA contraction: `(left * right) + add` yields a single
    // TensorMultiplyAdd built from both stored operands and the normalized
    // addend (normalizeOperand(add)). The return-type lines (119-121) and
    // the start of the return statement (124-125) are missing from this
    // listing.
118  template<typename ADD>
122  operator+(ADD const& add) const
123  {
126  m_left_operand, m_right_operand, normalizeOperand(add));
127  }
128 
    // Debug print of the expression tree via printf, in the shape
    //   Multiply[...](<left>, <right>)
    // where each operand prints itself through its own print_ast().
    // Line 135 (between the brackets) is missing from this listing;
    // presumably it prints multiply_op's own tag -- confirm.
129  RAJA_INLINE
130 
132  void print_ast() const
133  {
134  printf("Multiply[");
136  printf("](");
137  m_left_operand.print_ast();
138  printf(", ");
139  m_right_operand.print_ast();
140  printf(")");
141  }
142 };
143 
144 /*
145  * Overload for: arithmetic * tensorexpression
 *
 * Takes part in overload resolution only when LHS is an arithmetic type
 * (std::is_arithmetic<LHS>) and RHS derives from
 * TensorExpressionConcreteBase (std::is_base_of). The scalar is wrapped with
 * NormalizeOperandHelper<LHS>::normalize, and the result is a
 * TensorMultiply<NormalizeOperandHelper<LHS>::return_type, RHS> node holding
 * the wrapped scalar and right_operand.
 *
 * NOTE(review): lines 157 (the trailing return type) and 159 (the start of
 * the return statement) are missing from this listing.
146 
147  */
148 template<
149  typename LHS,
150  typename RHS,
151  typename std::enable_if<std::is_arithmetic<LHS>::value, bool>::type = true,
152  typename std::enable_if<
153  std::is_base_of<TensorExpressionConcreteBase, RHS>::value,
154  bool>::type = true>
155 RAJA_INLINE RAJA_HOST_DEVICE auto operator*(LHS const& left_operand,
156  RHS const& right_operand)
158 {
160  NormalizeOperandHelper<LHS>::normalize(left_operand), right_operand);
161 }
162 
163 } // namespace ET
164 
165 } // namespace expt
166 } // namespace internal
167 
168 } // namespace RAJA
169 
170 
171 #endif
RAJA header file defining SIMD/SIMT register operations.
RAJA header defining expression template behavior for operator*.
Definition: ExpressionTemplateBase.hpp:72
Definition: TensorMultiplyAdd.hpp:56
Definition: TensorMultiply.hpp:50
typename LEFT_OPERAND_TYPE::index_type index_type
Definition: TensorMultiply.hpp:58
RAJA_INLINE RAJA_HOST_DEVICE auto eval(TILE_TYPE const &tile) const -> decltype(multiply_op::multiply(tile, m_left_operand, m_right_operand))
Definition: TensorMultiply.hpp:86
typename LEFT_OPERAND_TYPE::element_type element_type
Definition: TensorMultiply.hpp:57
typename multiply_op::result_type result_type
Definition: TensorMultiply.hpp:60
RAJA_INLINE constexpr RAJA_HOST_DEVICE left_operand_type const & getLeftOperand() const
Definition: TensorMultiply.hpp:98
RAJA_INLINE constexpr RAJA_HOST_DEVICE int getDimSize(int dim) const
Definition: TensorMultiply.hpp:80
RIGHT_OPERAND_TYPE right_operand_type
Definition: TensorMultiply.hpp:54
RAJA_INLINE constexpr RAJA_HOST_DEVICE right_operand_type const & getRightOperand() const
Definition: TensorMultiply.hpp:109
LEFT_OPERAND_TYPE left_operand_type
Definition: TensorMultiply.hpp:53
static constexpr camp::idx_t s_num_dims
Definition: TensorMultiply.hpp:61
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE TensorMultiplyAdd< left_operand_type, right_operand_type, normalize_operand_t< ADD > > operator+(ADD const &add) const
Definition: TensorMultiply.hpp:122
RAJA_INLINE RAJA_HOST_DEVICE void print_ast() const
Definition: TensorMultiply.hpp:132
RAJA_INLINE RAJA_HOST_DEVICE TensorMultiply(left_operand_type const &left_operand, right_operand_type const &right_operand)
Definition: TensorMultiply.hpp:71
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
#define RAJA_SUPPRESS_HD_WARN
Definition: macros.hpp:68
RAJA_INLINE RAJA_HOST_DEVICE auto normalizeOperand(RHS const &rhs) -> typename NormalizeOperandHelper< RHS >::return_type
Definition: normalizeOperand.hpp:73
RAJA_INLINE RAJA_HOST_DEVICE auto operator*(LHS const &left_operand, RHS const &right_operand) -> TensorMultiply< typename NormalizeOperandHelper< LHS >::return_type, RHS >
Definition: TensorMultiply.hpp:155
typename NormalizeOperandHelper< RHS >::return_type normalize_operand_t
Definition: normalizeOperand.hpp:80
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE RAJA_INLINE void tile(CONTEXT const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch_core.hpp:589
Definition: MultiplyOperator.hpp:48
RAJA_INLINE static RAJA_HOST_DEVICE auto multiply(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right) -> decltype(left.eval(tile) * right.eval(tile))
Definition: MultiplyOperator.hpp:76
static constexpr camp::idx_t s_num_dims
Definition: MultiplyOperator.hpp:51
typename LEFT_OPERAND_TYPE::result_type result_type
Definition: MultiplyOperator.hpp:50
RAJA_INLINE static RAJA_HOST_DEVICE void print_ast()
Definition: MultiplyOperator.hpp:56
RAJA_INLINE static RAJA_HOST_DEVICE int getDimSize(int dim, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right)
Definition: MultiplyOperator.hpp:65
Definition: normalizeOperand.hpp:44