20 #ifndef RAJA_pattern_tensor_tensorref_HPP
21 #define RAJA_pattern_tensor_tensorref_HPP
23 #include "RAJA/config.hpp"
34 template<
typename INT_SEQ>
37 template<
typename INDEX_TYPE, INDEX_TYPE NEW_HEAD,
typename ARRAY>
40 template<
typename INDEX_TYPE,
size_t IDX, INDEX_TYPE DELTA,
typename ARRAY>
43 template<
typename INDEX_TYPE,
size_t IDX, INDEX_TYPE DELTA,
typename ARRAY>
46 template<
typename INDEX_TYPE, INDEX_TYPE HEAD, INDEX_TYPE... TAIL>
50 using seq_type = camp::int_seq<INDEX_TYPE, HEAD, TAIL...>;
63 static constexpr INDEX_TYPE
value_at(
size_t index)
71 return Tail::value_at(index - 1);
86 return tail[index - 1];
95 printf(
"%ld ", (
long)HEAD);
110 template<
typename INDEX_TYPE>
123 static constexpr INDEX_TYPE
value_at(
size_t) {
return 0; }
141 template<
typename INDEX_TYPE, INDEX_TYPE NEW_HEAD, INDEX_TYPE... ORIG_INTS>
149 using Seq = camp::int_seq<INDEX_TYPE, NEW_HEAD, ORIG_INTS...>;
152 template<
typename INDEX_TYPE,
173 template<
typename INDEX_TYPE,
193 template<
typename INDEX_TYPE,
214 template<
typename INDEX_TYPE,
240 template<
typename INDEX_TYPE, TensorTileSize TENSOR_SIZE, camp::
idx_t NUM_DIMS>
252 template<
typename I, TensorTileSize S>
255 for (camp::idx_t i = 0; i < NUM_DIMS; ++i)
269 template<
typename INDEX_TYPE2, TensorTileSize TENSOR_SIZE2>
286 printf(
"TensorTile: dims=%d, m_begin=[", (
int)NUM_DIMS);
288 for (camp::idx_t i = 0; i < NUM_DIMS; ++i)
290 printf(
"%ld ", (
long)
m_begin[i]);
293 printf(
"], m_size=[");
295 for (camp::idx_t i = 0; i < NUM_DIMS; ++i)
297 printf(
"%ld ", (
long)
m_size[i]);
305 template<
typename INDEX_TYPE,
311 template<
typename INDEX_TYPE,
313 INDEX_TYPE... BeginInts,
314 INDEX_TYPE... SizeInts>
317 camp::int_seq<INDEX_TYPE, BeginInts...>,
318 camp::int_seq<INDEX_TYPE, SizeInts...>>
322 using begin_seq = camp::int_seq<INDEX_TYPE, BeginInts...>;
323 using size_seq = camp::int_seq<INDEX_TYPE, SizeInts...>;
331 TensorTile<INDEX_TYPE, TENSOR_SIZE,
sizeof...(BeginInts)>;
340 static_assert(
sizeof...(BeginInts) ==
sizeof...(SizeInts),
341 "Mismatch between number of elements in Begin and Size series "
342 "of StaticTensorTile");
344 static constexpr camp::idx_t s_num_dims =
sizeof...(BeginInts);
354 template<TensorTileSize S>
364 printf(
"StaticTensorTile: dims=%d, m_begin=", (
int)s_num_dims);
376 template<
typename TILE,
typename VALUE,
size_t IDX>
379 template<
typename INDEX_TYPE,
387 camp::integral_constant<INDEX_TYPE, VALUE>,
398 template<
typename TILE,
typename VALUE,
size_t IDX>
401 template<
typename INDEX_TYPE,
409 camp::integral_constant<INDEX_TYPE, VALUE>,
420 template<
typename POINTER_TYPE,
423 camp::idx_t NUM_DIMS,
424 camp::idx_t STRIDE_ONE_DIM = -1>
450 printf(
"TensorRef: dims=%d, m_pointer=%p, m_stride=[", (
int)NUM_DIMS,
453 for (camp::idx_t i = 0; i < NUM_DIMS; ++i)
458 printf(
"], stride_one_dim=%d\n", (
int)STRIDE_ONE_DIM);
465 template<
typename POINTER_TYPE,
468 typename STRIDE_TYPE,
471 camp::idx_t STRIDE_ONE_DIM = -1>
474 template<
typename POINTER_TYPE,
477 INDEX_TYPE... StrideInts,
478 INDEX_TYPE... BeginInts,
479 INDEX_TYPE... SizeInts,
480 camp::idx_t STRIDE_ONE_DIM>
484 camp::int_seq<INDEX_TYPE, StrideInts...>,
485 camp::int_seq<INDEX_TYPE, BeginInts...>,
486 camp::int_seq<INDEX_TYPE, SizeInts...>,
490 static constexpr camp::idx_t s_num_dims =
sizeof...(BeginInts);
491 static constexpr camp::idx_t s_stride_one_dim = STRIDE_ONE_DIM;
497 using begin_seq = camp::int_seq<INDEX_TYPE, BeginInts...>;
498 using size_seq = camp::int_seq<INDEX_TYPE, SizeInts...>;
502 static_assert((
sizeof...(BeginInts) ==
sizeof...(SizeInts)) &&
503 (
sizeof...(SizeInts) ==
sizeof...(StrideInts)),
504 "Mismatch between number of elements in Begin and Size series "
505 "of StaticTensorRef");
527 printf(
"StaticTensorRef: dims=%d, m_pointer=%p, m_stride=", (
int)s_num_dims,
532 printf(
", stride_one_dim=%d\n", (
int)STRIDE_ONE_DIM);
539 template<
typename REF_TYPE,
typename TILE_TYPE,
typename DIM_SEQ>
542 template<
typename REF_TYPE,
typename TILE_TYPE, camp::idx_t... DIM_SEQ>
547 REF_TYPE::s_num_dims == TILE_TYPE::s_num_dims,
548 "Merging a ref with a tile requires an equivalent number of dimensions.");
550 static constexpr camp::idx_t s_num_dims = REF_TYPE::s_num_dims;
551 static constexpr camp::idx_t s_stride_one_dim = REF_TYPE::s_stride_one_dim;
579 TILE_TYPE
const& tile_origin)
582 ref.
m_pointer - RAJA::sum<camp::idx_t>((tile_origin.m_begin[DIM_SEQ] *
583 ref.m_stride[DIM_SEQ])...),
589 template<
typename POINTER_TYPE,
590 typename INDEX_TYPE1,
593 INDEX_TYPE1... BEGIN1,
594 INDEX_TYPE1... SIZE1,
595 camp::idx_t STRIDE_ONE_DIM,
596 typename INDEX_TYPE2,
600 camp::idx_t... DIM_SEQ>
605 camp::int_seq<INDEX_TYPE1, BEGIN1...>,
606 camp::int_seq<INDEX_TYPE1, SIZE1...>,
609 camp::idx_seq<DIM_SEQ...>>
614 camp::int_seq<INDEX_TYPE1, BEGIN1...>,
615 camp::int_seq<INDEX_TYPE1, SIZE1...>>;
621 camp::int_seq<INDEX_TYPE1, BEGIN1...>,
622 camp::int_seq<INDEX_TYPE1, SIZE1...>,
630 camp::int_seq<INDEX_TYPE2,
631 INDEX_TYPE2(ref_stride_type::value_at(DIM_SEQ))...>;
675 ref.m_pointer - RAJA::sum<camp::idx_t>((tile_origin.m_begin[DIM_SEQ] *
676 ref.m_stride[DIM_SEQ])...),
681 template<
typename REF_TYPE,
typename TILE_TYPE>
684 TILE_TYPE
const&
tile) ->
688 camp::make_idx_seq_t<TILE_TYPE::s_num_dims>>::merge_type
691 camp::make_idx_seq_t<TILE_TYPE::s_num_dims>>::merge(ref,
699 template<
typename REF_TYPE,
typename TILE_TYPE>
702 TILE_TYPE
const& tile_origin) ->
706 camp::make_idx_seq_t<TILE_TYPE::s_num_dims>>::shift_type
710 camp::make_idx_seq_t<TILE_TYPE::s_num_dims>>::shift_origin(ref,
717 template<
typename INDEX_TYPE, TensorTileSize RTENSOR_SIZE, camp::
idx_t NUM_DIMS>
729 template<
typename INDEX_TYPE, TensorTileSize RTENSOR_SIZE, camp::
idx_t NUM_DIMS>
742 template<
typename INDEX_TYPE,
753 return reinterpret_cast<
760 template<
typename INDEX_TYPE,
771 return reinterpret_cast<
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
#define RAJA_UNUSED_ARG(x)
Definition: macros.hpp:97
TensorTileSize
Definition: TensorRef.hpp:234
@ TENSOR_MULTIPLE
Definition: TensorRef.hpp:237
@ TENSOR_PARTIAL
Definition: TensorRef.hpp:235
@ TENSOR_FULL
Definition: TensorRef.hpp:236
RAJA_INLINE constexpr RAJA_HOST_DEVICE TensorTile< INDEX_TYPE, TENSOR_FULL, NUM_DIMS > & make_tensor_tile_full(TensorTile< INDEX_TYPE, RTENSOR_SIZE, NUM_DIMS > &tile)
Definition: TensorRef.hpp:721
RAJA_INLINE constexpr RAJA_HOST_DEVICE auto shift_tile_origin(REF_TYPE const &ref, TILE_TYPE const &tile_origin) -> typename MergeRefTile< REF_TYPE, TILE_TYPE, camp::make_idx_seq_t< TILE_TYPE::s_num_dims >>::shift_type
Definition: TensorRef.hpp:700
RAJA_INLINE constexpr RAJA_HOST_DEVICE auto merge_ref_tile(REF_TYPE const &ref, TILE_TYPE const &tile) -> typename MergeRefTile< REF_TYPE, TILE_TYPE, camp::make_idx_seq_t< TILE_TYPE::s_num_dims >>::merge_type
Definition: TensorRef.hpp:682
RAJA_INLINE constexpr RAJA_HOST_DEVICE TensorTile< INDEX_TYPE, TENSOR_PARTIAL, NUM_DIMS > & make_tensor_tile_partial(TensorTile< INDEX_TYPE, RTENSOR_SIZE, NUM_DIMS > &tile)
Definition: TensorRef.hpp:733
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE RAJA_INLINE void tile(CONTEXT const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch_core.hpp:589
typename PrependStaticIndexArray< INDEX_TYPE, HEAD+DELTA, typename Orig::Tail >::Type Type
Definition: TensorRef.hpp:187
typename PrependStaticIndexArray< INDEX_TYPE, HEAD+DELTA, typename Orig::Tail >::Seq Seq
Definition: TensorRef.hpp:190
typename PrependStaticIndexArray< INDEX_TYPE, HEAD, AddTail >::Seq Seq
Definition: TensorRef.hpp:170
typename PrependStaticIndexArray< INDEX_TYPE, HEAD, AddTail >::Type Type
Definition: TensorRef.hpp:169
typename AddStaticIndexArray< INDEX_TYPE, IDX - 1, DELTA, typename Orig::Tail >::Type AddTail
Definition: TensorRef.hpp:167
Definition: TensorRef.hpp:41
typename REF_TYPE::index_type ref_index_type
Definition: TensorRef.hpp:554
typename TILE_TYPE::index_type tile_index_type
Definition: TensorRef.hpp:557
RAJA_INLINE static constexpr RAJA_HOST_DEVICE merge_type merge(REF_TYPE const &ref, TILE_TYPE const &tile)
Definition: TensorRef.hpp:569
RAJA_INLINE static constexpr RAJA_HOST_DEVICE shift_type shift_origin(REF_TYPE const &ref, TILE_TYPE const &tile_origin)
Definition: TensorRef.hpp:578
typename REF_TYPE::pointer_type pointer_type
Definition: TensorRef.hpp:553
RAJA::internal::expt::MergeRefTile< StaticTensorRef< POINTER_TYPE, INDEX_TYPE1, RTENSOR_SIZE, STRIDE, camp::int_seq< INDEX_TYPE1, BEGIN1... >, camp::int_seq< INDEX_TYPE1, SIZE1... >, STRIDE_ONE_DIM >, StaticTensorTile< INDEX_TYPE2, TENSOR_SIZE, BEGIN2, SIZE2 >, camp::idx_seq< DIM_SEQ... > >::merge RAJA_INLINE static constexpr RAJA_HOST_DEVICE merge_type merge(ref_type const &ref, tile_type const &tile)
Definition: TensorRef.hpp:663
RAJA::internal::expt::MergeRefTile< StaticTensorRef< POINTER_TYPE, INDEX_TYPE1, RTENSOR_SIZE, STRIDE, camp::int_seq< INDEX_TYPE1, BEGIN1... >, camp::int_seq< INDEX_TYPE1, SIZE1... >, STRIDE_ONE_DIM >, StaticTensorTile< INDEX_TYPE2, TENSOR_SIZE, BEGIN2, SIZE2 >, camp::idx_seq< DIM_SEQ... > >::new_stride_seq camp::int_seq< INDEX_TYPE2, INDEX_TYPE2(ref_stride_type::value_at(DIM_SEQ))... > new_stride_seq
Definition: TensorRef.hpp:631
RAJA::internal::expt::MergeRefTile< StaticTensorRef< POINTER_TYPE, INDEX_TYPE1, RTENSOR_SIZE, STRIDE, camp::int_seq< INDEX_TYPE1, BEGIN1... >, camp::int_seq< INDEX_TYPE1, SIZE1... >, STRIDE_ONE_DIM >, StaticTensorTile< INDEX_TYPE2, TENSOR_SIZE, BEGIN2, SIZE2 >, camp::idx_seq< DIM_SEQ... > >::ref_stride_type typename ref_type ::stride_type ref_stride_type
Definition: TensorRef.hpp:627
RAJA::internal::expt::MergeRefTile< StaticTensorRef< POINTER_TYPE, INDEX_TYPE1, RTENSOR_SIZE, STRIDE, camp::int_seq< INDEX_TYPE1, BEGIN1... >, camp::int_seq< INDEX_TYPE1, SIZE1... >, STRIDE_ONE_DIM >, StaticTensorTile< INDEX_TYPE2, TENSOR_SIZE, BEGIN2, SIZE2 >, camp::idx_seq< DIM_SEQ... > >::shift_origin RAJA_INLINE static constexpr RAJA_HOST_DEVICE shift_type shift_origin(ref_type const &ref, tile_type const &tile_origin)
Definition: TensorRef.hpp:671
RAJA::internal::expt::MergeRefTile< StaticTensorRef< POINTER_TYPE, INDEX_TYPE1, RTENSOR_SIZE, STRIDE, camp::int_seq< INDEX_TYPE1, BEGIN1... >, camp::int_seq< INDEX_TYPE1, SIZE1... >, STRIDE_ONE_DIM >, StaticTensorTile< INDEX_TYPE2, TENSOR_SIZE, BEGIN2, SIZE2 >, camp::idx_seq< DIM_SEQ... > >::shift_size_seq camp::int_seq< INDEX_TYPE2, INDEX_TYPE2(SIZE1)... > shift_size_seq
Definition: TensorRef.hpp:634
RAJA::internal::expt::MergeRefTile< StaticTensorRef< POINTER_TYPE, INDEX_TYPE1, RTENSOR_SIZE, STRIDE, camp::int_seq< INDEX_TYPE1, BEGIN1... >, camp::int_seq< INDEX_TYPE1, SIZE1... >, STRIDE_ONE_DIM >, StaticTensorTile< INDEX_TYPE2, TENSOR_SIZE, BEGIN2, SIZE2 >, camp::idx_seq< DIM_SEQ... > >::shift_begin_seq camp::int_seq< INDEX_TYPE2, INDEX_TYPE2(BEGIN1)... > shift_begin_seq
Definition: TensorRef.hpp:633
Definition: TensorRef.hpp:540
camp::int_seq< INDEX_TYPE, NEW_HEAD, ORIG_INTS... > Seq
Definition: TensorRef.hpp:149
Definition: TensorRef.hpp:38
typename PrependStaticIndexArray< INDEX_TYPE, VALUE, typename Orig::Tail >::Type Type
Definition: TensorRef.hpp:227
typename PrependStaticIndexArray< INDEX_TYPE, VALUE, typename Orig::Tail >::Seq Seq
Definition: TensorRef.hpp:230
typename PrependStaticIndexArray< INDEX_TYPE, HEAD, SetTail >::Seq Seq
Definition: TensorRef.hpp:211
typename SetStaticIndexArray< INDEX_TYPE, IDX - 1, VALUE, typename Orig::Tail >::Type SetTail
Definition: TensorRef.hpp:208
typename PrependStaticIndexArray< INDEX_TYPE, HEAD, SetTail >::Type Type
Definition: TensorRef.hpp:210
Definition: TensorRef.hpp:44
Definition: TensorRef.hpp:377
Definition: TensorRef.hpp:399
Definition: TensorRef.hpp:48
RAJA_HOST_DEVICE RAJA_INLINE void print_values() const
Definition: TensorRef.hpp:93
camp::int_seq< INDEX_TYPE, HEAD, TAIL... > seq_type
Definition: TensorRef.hpp:50
RAJA_HOST_DEVICE RAJA_INLINE void print() const
Definition: TensorRef.hpp:102
RAJA_HOST_DEVICE constexpr RAJA_INLINE INDEX_TYPE operator[](size_t index) const
Definition: TensorRef.hpp:78
RAJA_HOST_DEVICE static constexpr RAJA_INLINE INDEX_TYPE value_at(size_t index)
Definition: TensorRef.hpp:63
Tail tail
Definition: TensorRef.hpp:54
RAJA_HOST_DEVICE constexpr RAJA_INLINE INDEX_TYPE operator[](size_t) const
Definition: TensorRef.hpp:128
RAJA_HOST_DEVICE RAJA_INLINE void print_values() const
Definition: TensorRef.hpp:133
camp::int_seq< INDEX_TYPE > seq_type
Definition: TensorRef.hpp:114
RAJA_HOST_DEVICE static constexpr RAJA_INLINE INDEX_TYPE value_at(size_t)
Definition: TensorRef.hpp:123
RAJA_HOST_DEVICE RAJA_INLINE void print() const
Definition: TensorRef.hpp:138
Definition: TensorRef.hpp:35
pointer_type m_pointer
Definition: TensorRef.hpp:518
camp::int_seq< INDEX_TYPE, BeginInts... > begin_seq
Definition: TensorRef.hpp:497
RAJA_HOST_DEVICE RAJA_INLINE void print() const
Definition: TensorRef.hpp:525
INDEX_TYPE index_type
Definition: TensorRef.hpp:494
camp::int_seq< INDEX_TYPE, StrideInts... > stride_seq
Definition: TensorRef.hpp:496
POINTER_TYPE pointer_type
Definition: TensorRef.hpp:493
stride_type m_stride
Definition: TensorRef.hpp:519
camp::int_seq< INDEX_TYPE, SizeInts... > size_seq
Definition: TensorRef.hpp:498
tile_type m_tile
Definition: TensorRef.hpp:520
Definition: TensorRef.hpp:472
size_type m_size
Definition: TensorRef.hpp:338
INDEX_TYPE index_type
Definition: TensorRef.hpp:328
constexpr void copy(StaticTensorTile< INDEX_TYPE, S, begin_seq, size_seq > const RAJA_UNUSED_ARG(&c)) const
Definition: TensorRef.hpp:355
begin_type m_begin
Definition: TensorRef.hpp:337
RAJA_HOST_DEVICE RAJA_INLINE void print() const
Definition: TensorRef.hpp:362
constexpr nonstatic_self_type nonstatic() const
Definition: TensorRef.hpp:352
camp::int_seq< INDEX_TYPE, SizeInts... > size_seq
Definition: TensorRef.hpp:323
camp::int_seq< INDEX_TYPE, BeginInts... > begin_seq
Definition: TensorRef.hpp:322
Definition: TensorRef.hpp:309
Definition: TensorRef.hpp:426
index_type m_stride[NUM_DIMS]
Definition: TensorRef.hpp:442
static constexpr camp::idx_t s_num_dims
Definition: TensorRef.hpp:428
static constexpr TensorTileSize s_tensor_size
Definition: TensorRef.hpp:429
pointer_type m_pointer
Definition: TensorRef.hpp:441
POINTER_TYPE pointer_type
Definition: TensorRef.hpp:437
static constexpr camp::idx_t s_stride_one_dim
Definition: TensorRef.hpp:427
RAJA_HOST_DEVICE RAJA_INLINE void print() const
Definition: TensorRef.hpp:448
INDEX_TYPE index_type
Definition: TensorRef.hpp:438
tile_type m_tile
Definition: TensorRef.hpp:443
Definition: TensorRef.hpp:242
RAJA_HOST_DEVICE RAJA_INLINE void print() const
Definition: TensorRef.hpp:284
void copy(TensorTile< I, S, NUM_DIMS > const &c)
Definition: TensorRef.hpp:253
static constexpr TensorTileSize s_tensor_size
Definition: TensorRef.hpp:250
INDEX_TYPE index_type
Definition: TensorRef.hpp:245
index_type m_begin[NUM_DIMS]
Definition: TensorRef.hpp:246
RAJA_HOST_DEVICE RAJA_INLINE self_type operator-(TensorTile< INDEX_TYPE2, TENSOR_SIZE2, NUM_DIMS > const &sub) const
Definition: TensorRef.hpp:271
index_type m_size[NUM_DIMS]
Definition: TensorRef.hpp:247
static constexpr camp::idx_t s_num_dims
Definition: TensorRef.hpp:249
TensorTile< INDEX_TYPE, TENSOR_SIZE, NUM_DIMS > self_type
Definition: TensorRef.hpp:243