RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
TensorScalarLiteral.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_pattern_tensor_ET_ScalarLiteral_HPP
21 #define RAJA_pattern_tensor_ET_ScalarLiteral_HPP
22 
23 #include "RAJA/config.hpp"
24 
25 #include "RAJA/util/macros.hpp"
26 
28 
29 namespace RAJA
30 {
31 namespace internal
32 {
33 namespace expt
34 {
35 
36 
37 namespace ET
38 {
39 
40 
41 template<typename T>
42 class TensorScalarLiteral : public TensorExpressionBase<TensorScalarLiteral<T>>
43 {
44 public:
47  using element_type = T;
48  using result_type = T;
50 
51  static constexpr camp::idx_t s_num_dims = 0;
52 
53  RAJA_INLINE
54 
56  constexpr index_type getDimSize(index_type) const { return 0; }
57 
58  RAJA_INLINE
59 
61  explicit constexpr TensorScalarLiteral(element_type const& value) noexcept
62  : m_value {value}
63  {}
64 
65  template<typename TILE_TYPE>
66  RAJA_INLINE RAJA_HOST_DEVICE element_type eval(TILE_TYPE const&) const
67  {
68  return m_value;
69  }
70 
71  RAJA_INLINE
72 
74  void print_ast() const { printf("ScalarLiteral(%e)", (double)m_value); }
75 
76 private:
77  element_type m_value;
78 };
79 
80 /*
81  * For arithmetic values, we need to wrap in a constant value ET node
82  */
// Partial specialization enabled only when RHS is an arithmetic type
// (std::is_arithmetic<RHS>::value). Its normalize() turns a raw scalar
// operand into an expression operand by constructing a return_type from it.
// NOTE(review): this listing appears to be missing the specialization's
// class-head (the line that names the primary template, just before
// `RHS,`), as well as the declaration of `return_type`, which is used
// below but never declared in the visible lines. Presumably return_type is
// the TensorScalarLiteral<RHS> wrapper; TODO confirm against the original
// header before relying on this listing.
83 template<typename RHS>
85  RHS,
86  typename std::enable_if<std::is_arithmetic<RHS>::value>::type>
87 {
89 
90  RAJA_INLINE
91 
 // Wrap the scalar operand: returns a return_type constructed from rhs.
 // NOTE(review): the RAJA_HOST_DEVICE qualifier shown in this function's
 // cross-reference signature is also absent from the listing here.
93  static constexpr return_type normalize(RHS const& rhs)
94  {
95  return return_type(rhs);
96  }
97 };
98 
99 
100 } // namespace ET
101 
102 } // namespace expt
103 } // namespace internal
104 
105 } // namespace RAJA
106 
107 
108 #endif
RAJA header file defining SIMD/SIMT register operations.
Definition: TensorRegister.hpp:46
Definition: ExpressionTemplateBase.hpp:72
Definition: TensorScalarLiteral.hpp:43
RAJA_INLINE constexpr RAJA_HOST_DEVICE index_type getDimSize(index_type) const
Definition: TensorScalarLiteral.hpp:56
RAJA_INLINE RAJA_HOST_DEVICE element_type eval(TILE_TYPE const &) const
Definition: TensorScalarLiteral.hpp:66
RAJA_INLINE RAJA_HOST_DEVICE void print_ast() const
Definition: TensorScalarLiteral.hpp:74
static constexpr camp::idx_t s_num_dims
Definition: TensorScalarLiteral.hpp:51
T result_type
Definition: TensorScalarLiteral.hpp:48
T element_type
Definition: TensorScalarLiteral.hpp:47
RAJA_INLINE constexpr RAJA_HOST_DEVICE TensorScalarLiteral(element_type const &value) noexcept
Definition: TensorScalarLiteral.hpp:61
RAJA::Index_type index_type
Definition: TensorScalarLiteral.hpp:49
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
std::ptrdiff_t Index_type
Definition: types.hpp:226
Definition: ListSegment.hpp:416
RAJA_INLINE static constexpr RAJA_HOST_DEVICE return_type normalize(RHS const &rhs)
Definition: TensorScalarLiteral.hpp:93
Definition: normalizeOperand.hpp:44