RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
TensorTranspose.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_pattern_tensor_ET_TensorTranspose_HPP
21 #define RAJA_pattern_tensor_ET_TensorTranspose_HPP
22 
23 #include "RAJA/config.hpp"
24 
25 #include "RAJA/util/macros.hpp"
26 
28 
29 namespace RAJA
30 {
31 namespace internal
32 {
33 namespace expt
34 {
35 
36 
37 namespace ET
38 {
39 
40 template<typename ET_TYPE>
41 class TensorTranspose : public TensorExpressionBase<TensorTranspose<ET_TYPE>>
42 {
43 public:
45  using rhs_type = ET_TYPE;
46  using tensor_type = typename ET_TYPE::result_type;
47  using element_type = typename tensor_type::element_type;
48  using index_type = typename ET_TYPE::index_type;
49 
51  using tile_type = typename ET_TYPE::tile_type;
52  static constexpr camp::idx_t s_num_dims = ET_TYPE::s_num_dims;
53 
54  RAJA_INLINE
55 
57  TensorTranspose(rhs_type const& tensor) : m_tensor {tensor} {}
58 
59  RAJA_INLINE
60 
62  constexpr index_type getDimSize(index_type dim) const
63  {
64  return m_tensor.getDimSize(dim);
65  }
66 
67  template<typename TILE_TYPE>
68  RAJA_INLINE RAJA_HOST_DEVICE result_type eval(TILE_TYPE const& tile) const
69  {
70  // transpose which tile we are returning
71  TILE_TYPE trans_tile {{tile.m_begin[1], tile.m_begin[0]},
72  {tile.m_size[1], tile.m_size[0]}};
73 
74  // evaluate and return the transposed tile
75  return m_tensor.eval(trans_tile).transpose();
76  }
77 
78  RAJA_INLINE
79 
81  void print_ast() const
82  {
83  printf("Transpose(");
84  m_tensor.print_ast();
85  printf(")");
86  }
87 
88 private:
89  rhs_type m_tensor;
90 };
91 
92 
93 } // namespace ET
94 
95 } // namespace expt
96 } // namespace internal
97 
98 } // namespace RAJA
99 
100 
101 #endif
RAJA header file defining SIMD/SIMT register operations.
Definition: ExpressionTemplateBase.hpp:72
Definition: TensorTranspose.hpp:42
static constexpr camp::idx_t s_num_dims
Definition: TensorTranspose.hpp:52
typename tensor_type::element_type element_type
Definition: TensorTranspose.hpp:47
typename ET_TYPE::tile_type tile_type
Definition: TensorTranspose.hpp:51
typename ET_TYPE::index_type index_type
Definition: TensorTranspose.hpp:48
RAJA_INLINE constexpr RAJA_HOST_DEVICE index_type getDimSize(index_type dim) const
Definition: TensorTranspose.hpp:62
typename ET_TYPE::result_type tensor_type
Definition: TensorTranspose.hpp:46
tensor_type result_type
Definition: TensorTranspose.hpp:50
RAJA_INLINE RAJA_HOST_DEVICE result_type eval(TILE_TYPE const &tile) const
Definition: TensorTranspose.hpp:68
ET_TYPE rhs_type
Definition: TensorTranspose.hpp:45
RAJA_INLINE RAJA_HOST_DEVICE void print_ast() const
Definition: TensorTranspose.hpp:81
RAJA_INLINE RAJA_HOST_DEVICE TensorTranspose(rhs_type const &tensor)
Definition: TensorTranspose.hpp:57
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE RAJA_INLINE void tile(CONTEXT const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch_core.hpp:589