RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
TensorLiteral.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_pattern_tensor_ET_TensorLiteral_HPP
21 #define RAJA_pattern_tensor_ET_TensorLiteral_HPP
22 
23 #include "RAJA/config.hpp"
24 
25 #include "RAJA/util/macros.hpp"
26 
28 
29 namespace RAJA
30 {
31 namespace internal
32 {
33 namespace expt
34 {
35 
36 
37 namespace ET
38 {
39 
40 
41 template<typename TENSOR_TYPE>
42 class TensorLiteral : public TensorExpressionBase<TensorLiteral<TENSOR_TYPE>>
43 {
44 public:
46  using tensor_type = TENSOR_TYPE;
47  using element_type = typename TENSOR_TYPE::element_type;
50 
51  static constexpr camp::idx_t s_num_dims = result_type::s_num_dims;
52 
53  RAJA_INLINE
54 
56  constexpr index_type getDimSize(index_type dim) const
57  {
58  return tensor_type::s_dim_elem(dim);
59  }
60 
61  RAJA_INLINE
62 
64  explicit TensorLiteral(tensor_type const& value) : m_value {value} {}
65 
66  template<typename TILE_TYPE>
67  RAJA_INLINE RAJA_HOST_DEVICE result_type eval(TILE_TYPE const&) const
68  {
69  return result_type(m_value);
70  }
71 
72  RAJA_INLINE
73 
75  void print_ast() const { printf("TensorLiteral()"); }
76 
77 private:
78  tensor_type m_value;
79 };
80 
81 /*
82  * For TensorRegister nodes, we need to wrap this in a constant value ET node.
 *
 * Partial specialization selected via std::enable_if when RHS derives from
 * TensorRegisterConcreteBase; normalize() returns return_type(rhs).
 *
 * NOTE(review): this extracted listing is missing the specialization's
 * class-head line (original line 85) and the `return_type` alias (original
 * line 90).  The primary template is presumably the one declared at
 * normalizeOperand.hpp:44, and return_type is presumably TensorLiteral<RHS>
 * (the "constant value ET node" above) — confirm against the real header.
83  */
84 template<typename RHS>
// NOTE(review): class-head line (original line 85) missing from this listing.
86  RHS,
87  typename std::enable_if<
88  std::is_base_of<TensorRegisterConcreteBase, RHS>::value>::type>
89 {
// NOTE(review): `return_type` alias (original line 90) missing from this
// listing; presumably `using return_type = TensorLiteral<RHS>;` — confirm.
91 

92  RAJA_INLINE
93 
// NOTE(review): per the symbol tooltip this member is also RAJA_HOST_DEVICE
// (original line 94, missing from this listing).
// Static and constexpr: normalizes rhs by constructing return_type from it.
95  static constexpr return_type normalize(RHS const& rhs)
96  {
97  return return_type(rhs);
98  }
99 };
100 
101 } // namespace ET
102 
103 } // namespace expt
104 } // namespace internal
105 
106 } // namespace RAJA
107 
108 
109 #endif
RAJA header file defining SIMD/SIMT register operations.
Definition: ExpressionTemplateBase.hpp:72
Definition: TensorLiteral.hpp:43
TENSOR_TYPE tensor_type
Definition: TensorLiteral.hpp:46
RAJA_INLINE RAJA_HOST_DEVICE void print_ast() const
Definition: TensorLiteral.hpp:75
RAJA_INLINE RAJA_HOST_DEVICE TensorLiteral(tensor_type const &value)
Definition: TensorLiteral.hpp:64
tensor_type result_type
Definition: TensorLiteral.hpp:48
typename TENSOR_TYPE::element_type element_type
Definition: TensorLiteral.hpp:47
RAJA_INLINE constexpr RAJA_HOST_DEVICE index_type getDimSize(index_type dim) const
Definition: TensorLiteral.hpp:56
RAJA_INLINE RAJA_HOST_DEVICE result_type eval(TILE_TYPE const &) const
Definition: TensorLiteral.hpp:67
RAJA::Index_type index_type
Definition: TensorLiteral.hpp:49
static constexpr camp::idx_t s_num_dims
Definition: TensorLiteral.hpp:51
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
std::ptrdiff_t Index_type
Definition: types.hpp:226
Definition: ListSegment.hpp:416
RAJA_INLINE static constexpr RAJA_HOST_DEVICE return_type normalize(RHS const &rhs)
Definition: TensorLiteral.hpp:95
Definition: normalizeOperand.hpp:44