20 #ifndef RAJA_util_TypedViewBase_HPP
21 #define RAJA_util_TypedViewBase_HPP
23 #include <type_traits>
25 #include "RAJA/config.hpp"
29 #if defined(RAJA_ENABLE_VECTORIZATION)
43 template<camp::
idx_t,
typename T>
49 template<
typename IdxSeq,
typename T>
52 template<camp::idx_t...
Perm,
typename T>
55 using type = camp::list<typename IndexToType<Perm, T>::type...>;
58 template<
typename Perm>
65 template<
typename layout>
71 template<
typename IdxLin,
typename... DimTypes>
78 #if defined(RAJA_ENABLE_VECTORIZATION)
87 template<camp::
idx_t DIM,
typename ARGS,
typename IDX_SEQ>
88 struct GetTensorArgIdxExpanded;
90 template<camp::idx_t DIM,
typename... ARGS, camp::idx_t... IDX>
91 struct GetTensorArgIdxExpanded<DIM, camp::list<ARGS...>, camp::idx_seq<IDX...>>
94 static constexpr camp::idx_t value = RAJA::max<camp::idx_t>(
95 (internal::expt::isTensorIndex<ARGS>() &&
96 internal::expt::getTensorDim<ARGS>() == DIM
109 template<
typename... ARGS>
112 static constexpr camp::idx_t
value =
113 #if defined(RAJA_ENABLE_VECTORIZATION)
114 RAJA::sum<camp::idx_t>(
115 (internal::expt::isTensorIndex<ARGS>() ? 1 : 0)...);
121 #if defined(RAJA_ENABLE_VECTORIZATION)
125 template<camp::idx_t DIM,
typename... ARGS>
126 struct GetTensorArgIdx
128 static constexpr camp::idx_t value = detail::GetTensorArgIdxExpanded<
131 camp::make_idx_seq_t<
sizeof...(ARGS)>>::value;
134 template<camp::idx_t DIM,
typename... ARGS>
135 struct GetTensorArgIdx<DIM, camp::list<ARGS...>>
137 static constexpr camp::idx_t value = detail::GetTensorArgIdxExpanded<
140 camp::make_idx_seq_t<
sizeof...(ARGS)>>::value;
146 template<camp::idx_t DIM,
typename LAYOUT,
typename... ARGS>
147 RAJA_INLINE
RAJA_HOST_DEVICE static constexpr camp::idx_t get_tensor_args_begin(
148 LAYOUT
const& layout,
151 return RAJA::max<camp::idx_t>(
152 internal::expt::getTensorDim<ARGS>() == DIM
153 ? internal::expt::getTensorBegin<ARGS>(
154 args, layout.template get_dim_begin<
155 GetTensorArgIdx<DIM, ARGS...>::value>())
162 template<camp::idx_t DIM,
typename LAYOUT,
typename... ARGS>
163 RAJA_INLINE
RAJA_HOST_DEVICE static constexpr camp::idx_t get_tensor_args_size(
164 LAYOUT
const& layout,
167 return RAJA::max<camp::idx_t>(
168 internal::expt::getTensorDim<ARGS>() == DIM
169 ? internal::expt::getTensorSize<ARGS>(
170 args, layout.template get_dim_size<
171 GetTensorArgIdx<DIM, ARGS...>::value>())
188 template<
typename VecSeq,
190 typename ElementType,
191 typename PointerType,
199 template<
typename... Args,
200 typename ElementType,
201 typename PointerType,
214 LayoutType
const& layout,
215 PointerType
const& data,
223 #if defined(RAJA_ENABLE_VECTORIZATION)
227 template<camp::idx_t VecHead,
228 camp::idx_t... VecSeq,
230 typename ElementType,
231 typename PointerType,
234 struct ViewReturnHelper<camp::idx_seq<VecHead, VecSeq...>,
242 static constexpr camp::idx_t s_num_dims =
sizeof...(VecSeq) + 1;
251 static constexpr camp::idx_t s_stride_one_dim = RAJA::max<camp::idx_t>(
252 (GetTensorArgIdx<VecHead, Args...>::value == LayoutType::stride_one_dim
255 (GetTensorArgIdx<VecSeq, Args...>::value == LayoutType::stride_one_dim
260 using tensor_reg_type =
261 typename camp::at_v<camp::list<Args...>,
262 GetTensorArgIdx<0, Args...>::value>::tensor_type;
272 LayoutType
const& layout,
273 PointerType
const& data,
277 return return_type(ref_type {
280 layout(internal::expt::isTensorIndex<Args>()
285 {(LinIdx)layout.template get_dim_stride<
286 GetTensorArgIdx<VecHead, Args...>::value>(),
287 (LinIdx)layout.template get_dim_stride<
288 GetTensorArgIdx<VecSeq, Args...>::value>()...},
291 {(LinIdx)(get_tensor_args_begin<VecHead>(layout,
args...)),
292 (LinIdx)(get_tensor_args_begin<VecSeq>(layout,
args...))...},
295 {(LinIdx)get_tensor_args_size<VecHead>(layout,
args...),
296 (LinIdx)get_tensor_args_size<VecSeq>(layout,
args...)...}}});
303 template<camp::idx_t VecHead,
304 camp::idx_t... VecSeq,
305 typename... INDEX_TYPES,
306 typename ElementType,
307 typename PointerType,
311 LinIdx... StrideInts,
313 struct ViewReturnHelper<
314 camp::idx_seq<VecHead, VecSeq...>,
315 camp::list<RAJA::expt::StaticTensorIndex<INDEX_TYPES>...>,
320 camp::int_seq<LinIdx, RangeInts...>,
321 camp::int_seq<LinIdx, SizeInts...>,
322 camp::int_seq<LinIdx, StrideInts...>,
325 static constexpr camp::idx_t s_num_dims =
sizeof...(VecSeq) + 1;
327 using index_list = camp::list<RAJA::expt::StaticTensorIndex<INDEX_TYPES>...>;
329 using range_seq = camp::int_seq<LinIdx, RangeInts...>;
330 using size_seq = camp::int_seq<LinIdx, SizeInts...>;
331 using stride_seq = camp::int_seq<LinIdx, StrideInts...>;
332 using LayoutType = RAJA::detail::
333 StaticLayoutBase_impl<LinIdx, range_seq, size_seq, stride_seq, DIM_LIST>;
342 static constexpr camp::idx_t s_stride_one_dim = RAJA::max<camp::idx_t>(
343 (GetTensorArgIdx<VecHead, index_list>::value == LayoutType::stride_one_dim
346 (GetTensorArgIdx<VecSeq, index_list>::value == LayoutType::stride_one_dim
351 using new_begin_seq =
352 camp::int_seq<LinIdx,
353 (LinIdx)get_tensor_args_begin<VecHead>(
356 (LinIdx)get_tensor_args_begin<VecSeq>(
360 camp::int_seq<LinIdx,
361 (LinIdx)get_tensor_args_size<VecHead>(
364 (LinIdx)get_tensor_args_size<VecSeq>(
368 using new_begin_type = internal::expt::StaticIndexArray<new_begin_seq>;
369 using new_size_type = internal::expt::StaticIndexArray<new_size_seq>;
372 using tensor_reg_type =
373 typename camp::at_v<index_list,
374 GetTensorArgIdx<0, index_list>::value>::tensor_type;
376 internal::expt::StaticTensorRef<ElementType*,
384 internal::expt::ET::TensorLoadStore<tensor_reg_type, ref_type>;
387 LayoutType
const& layout,
388 PointerType
const& data,
392 return return_type(ref_type {
397 INDEX_TYPES>::base_type>()
402 typename ref_type::stride_type(),
404 {new_begin_type(), new_size_type()}});
420 template<
typename ElementType,
421 typename PointerType,
431 LayoutType>::return_type;
441 template<
typename ElementType,
444 typename PointerType,
452 PointerType
const& data,
457 camp::list<Args...>, ElementType, PointerType, LinIdx,
458 LayoutType>::make_return(layout, data,
args...);
474 template<
typename Expected,
typename Arg>
479 "Argument isn't compatible");
490 #if defined(RAJA_ENABLE_VECTORIZATION)
497 template<
typename Expected,
typename Arg,
typename VectorType, camp::
idx_t DIM>
498 struct MatchTypedViewArgHelper<Expected,
504 "Argument isn't compatible");
523 template<
typename Expected,
528 strip_index_type_t<Arg> LENGTH>
529 struct MatchTypedViewArgHelper<
531 RAJA::expt::StaticTensorIndex<
533 StaticTensorIndexInner<Arg, VectorType, DIM, BEGIN, LENGTH>>>
536 static_assert(std::is_convertible<strip_index_type_t<Arg>,
537 strip_index_type_t<Expected>>::value,
538 "Argument isn't compatible");
540 using arg_type = strip_index_type_t<Arg>;
544 StaticTensorIndexInner<arg_type, VectorType, DIM, BEGIN, LENGTH>>;
549 StaticTensorIndexInner<Arg, VectorType, DIM, BEGIN, LENGTH>>
559 template<
typename Expected,
typename Arg>
567 template<
typename ValueType,
typename Po
interType,
typename LayoutType>
578 typename std::remove_pointer<pointer_type>::type>::type>::type;
602 #if (defined(RAJA_ENABLE_CUDA) || defined(RAJA_ENABLE_CLANG_CUDA))
634 template<
typename... Args>
641 template<bool IsConstView = std::is_const<value_type>::value>
643 typename std::enable_if<IsConstView, NonConstView>::type
const& rhs)
668 template<camp::
idx_t DIM>
671 return m_layout.template get_dim_size<DIM>();
674 template<
typename... Args>
682 return view_make_return_value<value_type, linear_index_type>(
692 template<
typename... Args>
700 return view_make_return_value<value_type, linear_index_type>(
705 template<
size_t n_dims = layout_type::n_dims,
708 const std::array<IdxLin, n_dims>&
shift)
710 static_assert(n_dims == layout_type::n_dims,
711 "Dimension mismatch in ViewBase shift");
714 shift_layout.shift(
shift);
721 template<
typename ValueType,
722 typename PointerType,
727 template<
typename ValueType,
728 typename PointerType,
730 typename... IndexTypes>
734 camp::list<IndexTypes...>>
735 :
public ViewBase<ValueType, PointerType, LayoutType>
745 typename std::remove_pointer<pointer_type>::type>::type>::type;
751 camp::list<IndexTypes...>>;
755 camp::list<IndexTypes...>>;
761 camp::list<IndexTypes...>>;
763 static constexpr
size_t n_dims =
sizeof...(IndexTypes);
767 template<
typename... Args>
775 return view_make_return_value<value_type, linear_index_type>(
776 Base::m_layout, Base::m_data,
777 match_typed_view_arg<IndexTypes>(
args)...);
786 template<
typename... Args>
794 return view_make_return_value<value_type, linear_index_type>(
795 Base::m_layout, Base::m_data,
796 match_typed_view_arg<IndexTypes>(
args)...);
800 template<
size_t n_dims =
sizeof...(IndexTypes),
801 typename IdxLin = linear_index_type>
803 const std::array<IdxLin, n_dims>& shift)
805 static_assert(n_dims == layout_type::n_dims,
806 "Dimension mismatch in TypedViewBase shift");
809 shift_layout.shift(shift);
811 return ShiftedView(Base::get_data(), shift_layout);
RAJA header file defining Layout, a N-dimensional index calculator.
RAJA header file defining Layout, a N-dimensional index calculator with offset indices.
RAJA header file defining Layout, a N-dimensional index calculator with compile-time defined sizes an...
Definition: TensorIndex.hpp:45
RAJA_INLINE constexpr RAJA_HOST_DEVICE value_type size() const
Definition: TensorIndex.hpp:151
RAJA_HOST_DEVICE constexpr RAJA_INLINE view_return_type_t< value_type, pointer_type, linear_index_type, layout_type, Args... > operator()(Args... args) const
Definition: TypedViewBase.hpp:773
RAJA_HOST_DEVICE constexpr RAJA_INLINE view_return_type_t< value_type, pointer_type, linear_index_type, layout_type, Args... > operator[](Args... args) const
Definition: TypedViewBase.hpp:792
RAJA_HOST_DEVICE constexpr RAJA_INLINE ShiftedView shift(const std::array< IdxLin, n_dims > &shift)
Definition: TypedViewBase.hpp:802
Definition: TypedViewBase.hpp:725
Definition: TypedViewBase.hpp:569
RAJA_HOST_DEVICE constexpr RAJA_INLINE ViewBase(pointer_type data, Args... dim_sizes)
Definition: TypedViewBase.hpp:635
RAJA_HOST_DEVICE constexpr RAJA_INLINE ViewBase(pointer_type data, layout_type &&layout)
Definition: TypedViewBase.hpp:629
ViewBase< value_type, pointer_type, shifted_layout_type > ShiftedView
Definition: TypedViewBase.hpp:584
RAJA_HOST_DEVICE constexpr RAJA_INLINE linear_index_type size() const
Definition: TypedViewBase.hpp:663
LayoutType layout_type
Definition: TypedViewBase.hpp:574
RAJA_HOST_DEVICE constexpr RAJA_INLINE view_return_type_t< value_type, pointer_type, linear_index_type, layout_type, Args... > operator[](Args... args) const
Definition: TypedViewBase.hpp:698
pointer_type m_data
Definition: TypedViewBase.hpp:587
RAJA_HOST_DEVICE constexpr RAJA_INLINE ViewBase(typename std::enable_if< IsConstView, NonConstView >::type const &rhs)
Definition: TypedViewBase.hpp:642
typename std::remove_const< value_type >::type nc_value_type
Definition: TypedViewBase.hpp:576
constexpr RAJA_INLINE ViewBase & operator=(ViewBase const &)=default
constexpr ViewBase()=default
typename std::add_pointer< typename std::remove_const< typename std::remove_pointer< pointer_type >::type >::type >::type nc_pointer_type
Definition: TypedViewBase.hpp:578
RAJA_HOST_DEVICE constexpr RAJA_INLINE pointer_type const & get_data() const
Definition: TypedViewBase.hpp:653
PointerType pointer_type
Definition: TypedViewBase.hpp:573
RAJA_HOST_DEVICE constexpr RAJA_INLINE ShiftedView shift(const std::array< IdxLin, n_dims > &shift)
Definition: TypedViewBase.hpp:707
RAJA_HOST_DEVICE constexpr RAJA_INLINE layout_type const & get_layout() const
Definition: TypedViewBase.hpp:658
layout_type const m_layout
Definition: TypedViewBase.hpp:588
typename add_offset< layout_type >::type shifted_layout_type
Definition: TypedViewBase.hpp:583
constexpr RAJA_INLINE ViewBase(ViewBase &&)=default
RAJA_HOST_DEVICE constexpr RAJA_INLINE view_return_type_t< value_type, pointer_type, linear_index_type, layout_type, Args... > operator()(Args... args) const
Definition: TypedViewBase.hpp:680
ValueType value_type
Definition: TypedViewBase.hpp:572
RAJA_HOST_DEVICE constexpr RAJA_INLINE void set_data(PointerType data_ptr)
Definition: TypedViewBase.hpp:648
constexpr RAJA_INLINE ViewBase & operator=(ViewBase &&)=default
constexpr RAJA_INLINE ViewBase(ViewBase const &)=default
typename layout_type::IndexLinear linear_index_type
Definition: TypedViewBase.hpp:575
RAJA_HOST_DEVICE constexpr RAJA_INLINE linear_index_type get_dim_size() const
Definition: TypedViewBase.hpp:669
Definition: TensorLoadStore.hpp:75
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
#define RAJA_UNUSED_ARG(x)
Definition: macros.hpp:97
Args args
Definition: WorkRunner.hpp:212
Definition: TypeTraits.hpp:30
RAJA_INLINE constexpr RAJA_HOST_DEVICE auto stripTensorIndexByValue(ARG const arg) -> typename TensorIndexTraits< ARG >::arg_type const
Definition: TensorIndexTraits.hpp:224
@ TENSOR_MULTIPLE
Definition: TensorRef.hpp:237
RAJA_INLINE constexpr RAJA_HOST_DEVICE bool isTensorIndex()
Definition: TensorIndexTraits.hpp:211
typename SequenceToType< Perm, RAJA::Index_type >::type getDefaultIndexTypes
Definition: TypedViewBase.hpp:60
typename detail::ViewReturnHelper< camp::make_idx_seq_t< count_num_tensor_args< Args... >::value >, camp::list< Args... >, ElementType, PointerType, LinIdx, LayoutType >::return_type view_return_type_t
Definition: TypedViewBase.hpp:431
RAJA_HOST_DEVICE constexpr RAJA_INLINE detail::MatchTypedViewArgHelper< Expected, Arg >::type match_typed_view_arg(Arg const &arg)
Definition: TypedViewBase.hpp:562
RAJA_INLINE constexpr RAJA_HOST_DEVICE view_return_type_t< ElementType, PointerType, LinIdx, LayoutType, Args... > view_make_return_value(LayoutType const &layout, PointerType const &data, Args const &... args)
Definition: TypedViewBase.hpp:451
Definition: AlignedRangeIndexSetBuilders.cpp:35
typename internal::StripIndexTypeT< FROM >::type strip_index_type_t
Strips a strongly typed index to its underlying type In the case of a non-strongly typed index,...
Definition: IndexValue.hpp:364
constexpr RAJA_HOST_DEVICE RAJA_INLINE std::enable_if< std::is_base_of< IndexValueBase, FROM >::value, typename FROM::value_type >::type stripIndexType(FROM const val)
Function that strips the strongly typed Index<> and returns its underlying value_type value.
Definition: IndexValue.hpp:323
camp::idx_seq< Ints... > Perm
Definition: PermutedLayout.hpp:101
RAJA header file defining atomic operations.
RAJA header file defining SIMD/SIMT register operations.
Definition: OffsetLayout.hpp:172
Definition: Layout.hpp:329
Definition: OffsetLayout.hpp:193
Definition: OffsetLayout.hpp:188
Definition: StaticLayout.hpp:48
Definition: TensorIndex.hpp:41
Definition: TypedViewBase.hpp:45
T type
Definition: TypedViewBase.hpp:46
camp::list< typename IndexToType< Perm, T >::type... > type
Definition: TypedViewBase.hpp:55
Definition: TypedViewBase.hpp:50
Definition: TypedViewBase.hpp:67
Definition: TypedViewBase.hpp:111
static constexpr camp::idx_t value
Definition: TypedViewBase.hpp:112
Definition: TypedViewBase.hpp:476
strip_index_type_t< Arg > type
Definition: TypedViewBase.hpp:481
static RAJA_HOST_DEVICE constexpr RAJA_INLINE type extract(Arg arg)
Definition: TypedViewBase.hpp:483
RAJA_INLINE static constexpr RAJA_HOST_DEVICE return_type make_return(LayoutType const &layout, PointerType const &data, Args const &... args)
Definition: TypedViewBase.hpp:213
ElementType & return_type
Definition: TypedViewBase.hpp:211
Definition: TypedViewBase.hpp:194
Definition: TensorRef.hpp:426