RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
TensorIndexTraits.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_pattern_tensor_TensorIndexTraits_HPP
21 #define RAJA_pattern_tensor_TensorIndexTraits_HPP
22 
23 #include "RAJA/config.hpp"
24 #include "RAJA/util/macros.hpp"
26 
27 namespace RAJA
28 {
29 
30 namespace internal
31 {
32 /* Partial specialization of StripIndexTypeT, the helper behind
33  strip_index_type_t in IndexValue.hpp: stripping a TensorIndex yields its value_type.
34 */
35 template<typename IDX, typename VECTOR_TYPE, camp::idx_t DIM>
36 struct StripIndexTypeT<RAJA::expt::TensorIndex<IDX, VECTOR_TYPE, DIM>>
37 {
38  using type =
40 };
41 
42 namespace expt
43 {
44 
45 
46 // Traits that strip the TensorIndex wrapper from an argument; this primary template handles plain (non-TensorIndex) arguments.
47 template<typename ARG>
49 {
50  using arg_type = ARG;
52 
53  RAJA_INLINE
54 
56  static constexpr bool isTensorIndex() { return false; }
57 
58  RAJA_INLINE
59 
61  static constexpr arg_type const& strip(arg_type const& arg) { return arg; }
62 
63  RAJA_INLINE
64 
66  static constexpr arg_type const strip_by_value(arg_type const arg)
67  {
68  return arg;
69  }
70 
71  RAJA_INLINE
72 
74  static constexpr value_type size(arg_type const&) { return 1; }
75 
76  RAJA_INLINE
77 
79  static constexpr value_type begin(arg_type const&) { return 0; }
80 
81  RAJA_INLINE
82 
84  static constexpr value_type dim() { return 0; }
85 
86  RAJA_INLINE
87 
89  static constexpr value_type num_elem() { return 1; }
90 };
91 
92 template<typename IDX, typename TENSOR_TYPE, camp::idx_t DIM>
93 struct TensorIndexTraits<RAJA::expt::TensorIndex<IDX, TENSOR_TYPE, DIM>>
94 {
96  using arg_type = IDX;
98 
99  RAJA_INLINE
100 
102  static constexpr bool isTensorIndex() { return true; }
103 
104  RAJA_INLINE
105 
107  static constexpr arg_type const& strip(index_type const& arg) { return *arg; }
108 
109  RAJA_INLINE
110 
112  static constexpr arg_type const strip_by_value(index_type const arg)
113  {
114  return (arg_type)arg;
115  }
116 
117  RAJA_INLINE
118 
120  static constexpr value_type size(index_type const& arg) { return arg.size(); }
121 
122  RAJA_INLINE
123 
125  static constexpr value_type begin(index_type const& arg)
126  {
127  return arg.begin();
128  }
129 
130  RAJA_INLINE
131 
133  static constexpr value_type dim() { return DIM; }
134 
135  RAJA_INLINE
136 
138  static constexpr value_type num_elem()
139  {
140  return TENSOR_TYPE::s_dim_elem(DIM);
141  }
142 };
143 
144 template<typename IDX,
145  typename TENSOR_TYPE,
146  camp::idx_t DIM,
147  IDX INDEX_VALUE,
148  strip_index_type_t<IDX> LENGTH_VALUE>
150  RAJA::expt::StaticTensorIndexInner<IDX,
151  TENSOR_TYPE,
152  DIM,
153  INDEX_VALUE,
154  LENGTH_VALUE>>>
155 {
159  TENSOR_TYPE,
160  DIM,
161  INDEX_VALUE,
162  LENGTH_VALUE>>;
163  using arg_type = IDX;
165 
166  RAJA_INLINE
167 
169  static constexpr bool isTensorIndex() { return true; }
170 
171  RAJA_INLINE
172 
174  static constexpr arg_type const strip_by_value(index_type const)
175  {
176  return INDEX_VALUE;
177  }
178 
179  RAJA_INLINE
180 
182  static constexpr value_type size(index_type const&) { return LENGTH_VALUE; }
183 
184  RAJA_INLINE
185 
187  static constexpr value_type begin(index_type const&) { return INDEX_VALUE; }
188 
189  RAJA_INLINE
190 
192  static constexpr value_type dim() { return DIM; }
193 
194  RAJA_INLINE
195 
197  static constexpr value_type num_elem()
198  {
199  return TENSOR_TYPE::s_dim_elem(DIM);
200  }
201 };
202 
203 /*
204  * Returns whether ARG is a tensor index type.
205  *
206  * For scalars (plain index arguments), always returns false.
207  *
208  * For TensorIndex and StaticTensorIndex types, returns true.
209  */
210 template<typename ARG>
211 RAJA_INLINE RAJA_HOST_DEVICE constexpr bool isTensorIndex()
212 {
214 }
215 
216 template<typename ARG>
217 RAJA_INLINE RAJA_HOST_DEVICE constexpr auto stripTensorIndex(ARG const& arg) ->
218  typename TensorIndexTraits<ARG>::arg_type const&
219 {
220  return TensorIndexTraits<ARG>::strip(arg);
221 }
222 
223 template<typename ARG>
224 RAJA_INLINE RAJA_HOST_DEVICE constexpr auto stripTensorIndexByValue(
225  ARG const arg) -> typename TensorIndexTraits<ARG>::arg_type const
226 {
228 }
229 
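230 /*
231  * Returns the extent of the argument along its tensor dimension.
232  *
233  * Uses TensorIndexTraits<ARG>::size(arg) (1 for scalars, the index length for tensor indices); if that is negative, dim_size is returned instead.
234  */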
235 template<typename ARG, typename IDX>
236 RAJA_INLINE RAJA_HOST_DEVICE constexpr IDX getTensorSize(ARG const& arg,
237  IDX dim_size)
238 {
239  return TensorIndexTraits<ARG>::size(arg) >= 0
240  ? IDX(TensorIndexTraits<ARG>::size(arg))
241  : dim_size;
242 }
243 
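244 /*
245  * Returns the beginning index of an argument along its tensor dimension.
246  * Scalars report 0; if the reported begin is negative, dim_minval is returned instead.
247  */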
248 template<typename ARG, typename IDX>
249 RAJA_INLINE RAJA_HOST_DEVICE constexpr IDX getTensorBegin(ARG const& arg,
250  IDX dim_minval)
251 {
252  return TensorIndexTraits<ARG>::begin(arg) >= 0
254  : dim_minval;
255 }
256 
257 /*
258  * Returns the tensor dimension of argument type ARG.
259  *
260  * For scalars, always returns 0.
261  *
262  * For TensorIndex and StaticTensorIndex types, returns the DIM template argument.
263  * For vector_exec, this is always 0
264  *
265  * For matrices, DIM means:
266  * 0 : Row
267  * 1 : Column
268  */
269 template<typename ARG>
270 RAJA_INLINE RAJA_HOST_DEVICE constexpr auto getTensorDim()
271  -> decltype(TensorIndexTraits<ARG>::dim())
272 {
274 }
275 
276 } // namespace expt
277 
278 /*
279  * Lambda<N, Seg<X>> overload that matches TensorIndex types: builds the
280  * TensorIndex from the segment value at the current offset and includes the vector length.
281  */
282 template<typename IDX, typename TENSOR_TYPE, camp::idx_t DIM, camp::idx_t id>
283 struct LambdaSegExtractor<RAJA::expt::TensorIndex<IDX, TENSOR_TYPE, DIM>, id>
284 {
285 
286  template<typename Data>
287  RAJA_HOST_DEVICE RAJA_INLINE constexpr static RAJA::expt::
288  TensorIndex<IDX, TENSOR_TYPE, DIM>
289  extract(Data&& data)
290  {
292  camp::get<id>(data.segment_tuple)
293  .begin()[camp::get<id>(data.offset_tuple)],
294  camp::get<id>(data.vector_sizes));
295  }
296 };
297 
298 /*
299  * Lambda offset-argument overload that matches TensorIndex types: builds the
300  * TensorIndex from the loop offset (converted to IDX) and includes the vector length.
301  */
302 template<typename IDX, typename TENSOR_TYPE, camp::idx_t DIM, camp::idx_t id>
303 struct LambdaOffsetExtractor<RAJA::expt::TensorIndex<IDX, TENSOR_TYPE, DIM>, id>
304 {
305 
306  template<typename Data>
307  RAJA_HOST_DEVICE RAJA_INLINE constexpr static RAJA::expt::
308  TensorIndex<IDX, TENSOR_TYPE, DIM>
309  extract(Data&& data)
310  {
312  IDX(camp::get<id>(data.offset_tuple)), // convert offset type to IDX
313  camp::get<id>(data.vector_sizes));
314  }
315 };
316 
317 } // namespace internal
318 } // namespace RAJA
319 
320 
321 #endif
RAJA header file defining SIMD/SIMT register operations.
Definition: TensorIndex.hpp:45
RAJA_INLINE constexpr RAJA_HOST_DEVICE value_type size() const
Definition: TensorIndex.hpp:151
strip_index_type_t< IDX > value_type
Definition: TensorIndex.hpp:48
RAJA_INLINE constexpr RAJA_HOST_DEVICE index_type begin() const
Definition: TensorIndex.hpp:146
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
RAJA_INLINE constexpr RAJA_HOST_DEVICE auto stripTensorIndexByValue(ARG const arg) -> typename TensorIndexTraits< ARG >::arg_type const
Definition: TensorIndexTraits.hpp:224
RAJA_INLINE constexpr RAJA_HOST_DEVICE IDX getTensorBegin(ARG const &arg, IDX dim_minval)
Definition: TensorIndexTraits.hpp:249
RAJA_INLINE constexpr RAJA_HOST_DEVICE auto stripTensorIndex(ARG const &arg) -> typename TensorIndexTraits< ARG >::arg_type const &
Definition: TensorIndexTraits.hpp:217
RAJA_INLINE constexpr RAJA_HOST_DEVICE IDX getTensorSize(ARG const &arg, IDX dim_size)
Definition: TensorIndexTraits.hpp:236
RAJA_INLINE constexpr RAJA_HOST_DEVICE auto getTensorDim() -> decltype(TensorIndexTraits< ARG >::dim())
Definition: TensorIndexTraits.hpp:270
RAJA_INLINE constexpr RAJA_HOST_DEVICE bool isTensorIndex()
Definition: TensorIndexTraits.hpp:211
Definition: AlignedRangeIndexSetBuilders.cpp:35
typename internal::StripIndexTypeT< FROM >::type strip_index_type_t
Strips a strongly typed index to its underlying type. In the case of a non-strongly typed index,...
Definition: IndexValue.hpp:364
Definition: TensorIndex.hpp:38
Definition: TensorIndex.hpp:41
RAJA_HOST_DEVICE constexpr static RAJA_INLINE RAJA::expt::TensorIndex< IDX, TENSOR_TYPE, DIM > extract(Data &&data)
Definition: TensorIndexTraits.hpp:309
Definition: Lambda.hpp:174
RAJA_HOST_DEVICE constexpr static RAJA_INLINE RAJA::expt::TensorIndex< IDX, TENSOR_TYPE, DIM > extract(Data &&data)
Definition: TensorIndexTraits.hpp:289
Definition: Lambda.hpp:149
typename RAJA::expt::TensorIndex< IDX, VECTOR_TYPE, DIM >::value_type type
Definition: TensorIndexTraits.hpp:39
Definition: IndexValue.hpp:344
RAJA_INLINE static constexpr RAJA_HOST_DEVICE value_type size(index_type const &)
Definition: TensorIndexTraits.hpp:182
RAJA_INLINE static constexpr RAJA_HOST_DEVICE value_type begin(index_type const &)
Definition: TensorIndexTraits.hpp:187
RAJA_INLINE static constexpr RAJA_HOST_DEVICE arg_type const strip_by_value(index_type const)
Definition: TensorIndexTraits.hpp:174
RAJA_INLINE static constexpr RAJA_HOST_DEVICE arg_type const & strip(index_type const &arg)
Definition: TensorIndexTraits.hpp:107
RAJA_INLINE static constexpr RAJA_HOST_DEVICE value_type dim()
Definition: TensorIndexTraits.hpp:133
RAJA_INLINE static constexpr RAJA_HOST_DEVICE value_type begin(index_type const &arg)
Definition: TensorIndexTraits.hpp:125
RAJA_INLINE static constexpr RAJA_HOST_DEVICE value_type num_elem()
Definition: TensorIndexTraits.hpp:138
strip_index_type_t< IDX > value_type
Definition: TensorIndexTraits.hpp:97
RAJA_INLINE static constexpr RAJA_HOST_DEVICE arg_type const strip_by_value(index_type const arg)
Definition: TensorIndexTraits.hpp:112
RAJA_INLINE static constexpr RAJA_HOST_DEVICE bool isTensorIndex()
Definition: TensorIndexTraits.hpp:102
RAJA_INLINE static constexpr RAJA_HOST_DEVICE value_type size(index_type const &arg)
Definition: TensorIndexTraits.hpp:120
Definition: TensorIndexTraits.hpp:49
RAJA_INLINE static constexpr RAJA_HOST_DEVICE value_type size(arg_type const &)
Definition: TensorIndexTraits.hpp:74
RAJA_INLINE static constexpr RAJA_HOST_DEVICE value_type num_elem()
Definition: TensorIndexTraits.hpp:89
ARG arg_type
Definition: TensorIndexTraits.hpp:50
RAJA_INLINE static constexpr RAJA_HOST_DEVICE arg_type const strip_by_value(arg_type const arg)
Definition: TensorIndexTraits.hpp:66
RAJA_INLINE static constexpr RAJA_HOST_DEVICE value_type begin(arg_type const &)
Definition: TensorIndexTraits.hpp:79
RAJA_INLINE static constexpr RAJA_HOST_DEVICE arg_type const & strip(arg_type const &arg)
Definition: TensorIndexTraits.hpp:61
RAJA_INLINE static constexpr RAJA_HOST_DEVICE value_type dim()
Definition: TensorIndexTraits.hpp:84
RAJA_INLINE static constexpr RAJA_HOST_DEVICE bool isTensorIndex()
Definition: TensorIndexTraits.hpp:56
strip_index_type_t< ARG > value_type
Definition: TensorIndexTraits.hpp:51