RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
TensorMultiplyAdd.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_pattern_tensor_ET_TensorMultiplyAddAdd_HPP
21 #define RAJA_pattern_tensor_ET_TensorMultiplyAddAdd_HPP
22 
#include "RAJA/config.hpp"

#include "RAJA/util/macros.hpp"

#include "RAJA/pattern/tensor/internal/ET/ExpressionTemplateBase.hpp"
#include "RAJA/pattern/tensor/internal/ET/MultiplyOperator.hpp"

30 namespace RAJA
31 {
32 namespace internal
33 {
34 namespace expt
35 {
36 
37 
38 namespace ET
39 {
40 
41 
49 template<typename LEFT_OPERAND_TYPE,
50  typename RIGHT_OPERAND_TYPE,
51  typename ADD_OPERAND_TYPE>
53  : public TensorExpressionBase<TensorMultiplyAdd<LEFT_OPERAND_TYPE,
54  RIGHT_OPERAND_TYPE,
55  ADD_OPERAND_TYPE>>
56 {
57 public:
58  using self_type = TensorMultiplyAdd<LEFT_OPERAND_TYPE,
59  RIGHT_OPERAND_TYPE,
60  ADD_OPERAND_TYPE>;
61  using left_operand_type = LEFT_OPERAND_TYPE;
62  using right_operand_type = RIGHT_OPERAND_TYPE;
63  using add_operand_type = ADD_OPERAND_TYPE;
65 
66  using element_type = typename LEFT_OPERAND_TYPE::element_type;
67  using index_type = typename LEFT_OPERAND_TYPE::index_type;
68 
70  static constexpr camp::idx_t s_num_dims = multiply_op::s_num_dims;
71 
72 private:
73  left_operand_type m_left_operand;
74  right_operand_type m_right_operand;
75  add_operand_type m_add_operand;
76 
77 public:
78  RAJA_INLINE
79 
81  TensorMultiplyAdd(left_operand_type const& left_operand,
82  right_operand_type const& right_operand,
83  add_operand_type const& add_operand)
84  : m_left_operand {left_operand},
85  m_right_operand {right_operand},
86  m_add_operand {add_operand}
87  {}
88 
89  template<typename TILE_TYPE>
90  RAJA_INLINE RAJA_HOST_DEVICE auto eval(TILE_TYPE const& tile) const
91  -> decltype(multiply_op::multiply_add(tile,
92  m_left_operand,
93  m_right_operand,
94  m_add_operand))
95  {
96  return multiply_op::multiply_add(tile, m_left_operand, m_right_operand,
97  m_add_operand);
98  }
99 
100  RAJA_INLINE
101 
103  void print_ast() const
104  {
105  printf("MultiplyAdd[");
107  printf("](");
108  m_left_operand.print_ast();
109  printf(", ");
110  m_right_operand.print_ast();
111  printf(", ");
112  m_add_operand.print_ast();
113  printf(")");
114  }
115 };
116 
117 
118 } // namespace ET
119 
120 } // namespace expt
121 } // namespace internal
122 
123 } // namespace RAJA
124 
125 
126 #endif
RAJA header file defining SIMD/SIMT register operations.
RAJA header defining expression template behavior for operator*.
Definition: ExpressionTemplateBase.hpp:72
Definition: TensorMultiplyAdd.hpp:56
ADD_OPERAND_TYPE add_operand_type
Definition: TensorMultiplyAdd.hpp:63
RAJA_INLINE RAJA_HOST_DEVICE auto eval(TILE_TYPE const &tile) const -> decltype(multiply_op::multiply_add(tile, m_left_operand, m_right_operand, m_add_operand))
Definition: TensorMultiplyAdd.hpp:90
typename multiply_op::result_type result_type
Definition: TensorMultiplyAdd.hpp:69
LEFT_OPERAND_TYPE left_operand_type
Definition: TensorMultiplyAdd.hpp:61
RIGHT_OPERAND_TYPE right_operand_type
Definition: TensorMultiplyAdd.hpp:62
RAJA_INLINE RAJA_HOST_DEVICE TensorMultiplyAdd(left_operand_type const &left_operand, right_operand_type const &right_operand, add_operand_type const &add_operand)
Definition: TensorMultiplyAdd.hpp:81
typename LEFT_OPERAND_TYPE::element_type element_type
Definition: TensorMultiplyAdd.hpp:66
static constexpr camp::idx_t s_num_dims
Definition: TensorMultiplyAdd.hpp:70
RAJA_INLINE RAJA_HOST_DEVICE void print_ast() const
Definition: TensorMultiplyAdd.hpp:103
typename LEFT_OPERAND_TYPE::index_type index_type
Definition: TensorMultiplyAdd.hpp:67
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE RAJA_INLINE void tile(CONTEXT const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch_core.hpp:589
Definition: MultiplyOperator.hpp:48
static constexpr camp::idx_t s_num_dims
Definition: MultiplyOperator.hpp:51
typename LEFT_OPERAND_TYPE::result_type result_type
Definition: MultiplyOperator.hpp:50
RAJA_INLINE static RAJA_HOST_DEVICE void print_ast()
Definition: MultiplyOperator.hpp:56
RAJA_INLINE static RAJA_HOST_DEVICE auto multiply_add(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right, ADD_OPERAND_TYPE const &add) -> decltype(left.eval(tile).multiply_add(right.eval(tile), add.eval(tile)))
Definition: MultiplyOperator.hpp:89