10 #ifndef FORALL_PARAM_HPP
11 #define FORALL_PARAM_HPP
17 #include "camp/camp.hpp"
18 #include "camp/concepts.hpp"
19 #include "camp/tuple.hpp"
33 template<
typename...
Params>
36 return camp::get_refs_to_elements_by_type_trait<is_instance_of_Reducer>(
40 template<
typename ExecPol,
45 const camp::idx_seq<Seq...>&,
49 std::forward<Args>(
args)...),
53 template<
typename ExecPol,
typename...
Params,
typename... Args>
57 using ParamTupleType = decltype(params);
58 resolve_params_helper<ExecPol>(
59 params, camp::make_idx_seq_t<camp::tuple_size<ParamTupleType>::value>(),
60 std::forward<Args>(
args)...);
63 template<
typename ExecPol,
68 const camp::idx_seq<Seq...>&,
71 (
param_init(ExecPol {}, camp::get<Seq>(params_tuple),
72 std::forward<Args>(
args)...),
76 template<
typename ExecPol,
typename...
Params,
typename... Args>
80 using ParamTupleType = decltype(params);
81 init_params_helper<ExecPol>(
82 params, camp::make_idx_seq_t<camp::tuple_size<ParamTupleType>::value>(),
83 std::forward<Args>(
args)...);
86 template<
typename ExecPol,
typename ParamTuple, camp::idx_t... Seq>
88 ParamTuple& params_tuple)
90 (
param_combine(ExecPol {}, camp::get<Seq>(params_tuple)), ...);
93 template<
typename EXEC_POL,
typename T>
94 camp::concepts::enable_if<
95 concepts::negate<is_instance_of_Reducer<camp::decay<T>>>,
96 concepts::negate<std::is_same<T, RAJA::detail::Name>>>
100 template<
typename ExecPol,
typename ParamTuple, camp::idx_t... Seq>
102 ParamTuple& params_tuple,
103 const ParamTuple& params_tuple_in)
106 camp::get<Seq>(params_tuple_in)),
110 template<
typename ExecPol,
typename...
Params>
114 using ParamTupleType = camp::decay<decltype(params)>;
115 combine_params_helper<ExecPol>(
116 camp::make_idx_seq_t<camp::tuple_size<ParamTupleType>::value>(), params);
119 template<
typename ExecPol,
typename...
Params>
121 camp::tuple<Params...>& params_tuple,
122 const camp::tuple<Params...>& params_tuple_in)
124 using ParamTupleType = camp::decay<decltype(params_tuple)>;
125 combine_params_helper<ExecPol>(
126 camp::make_idx_seq_t<camp::tuple_size<ParamTupleType>::value>(),
127 params_tuple, params_tuple_in);
136 struct ParamMultiplexer;
138 template<
typename...
Params>
152 template<
typename EXEC_POL, camp::idx_t... Seq,
typename... Args>
153 static constexpr
void parampack_init(EXEC_POL
const& pol,
154 camp::idx_seq<Seq...>,
159 std::forward<Args>(
args)...),
164 template<
typename EXEC_POL, camp::idx_t... Seq>
167 camp::idx_seq<Seq...>,
176 template<
typename EXEC_POL, camp::idx_t... Seq>
179 camp::idx_seq<Seq...>,
182 (
param_combine(pol, camp::get<Seq>(f_params.param_tup)), ...);
186 template<
typename EXEC_POL, camp::idx_t... Seq,
typename... Args>
187 static constexpr
void parampack_resolve(EXEC_POL
const& pol,
188 camp::idx_seq<Seq...>,
193 std::forward<Args>(
args)...),
198 template<
typename null_t = camp::nil>
199 static constexpr
auto LAMBDA_ARG_TUP_T()
201 return camp::tuple<> {};
204 template<
typename null_t = camp::nil,
typename First>
205 static constexpr
auto LAMBDA_ARG_TUP_T()
207 return typename First::ARG_TUP_T();
210 template<
typename null_t = camp::nil,
214 static constexpr
auto LAMBDA_ARG_TUP_T()
216 return camp::tuple_cat_pair(
typename First::ARG_TUP_T(),
217 LAMBDA_ARG_TUP_T<camp::nil, Second, Rest...>());
220 using lambda_arg_tuple_t = decltype(LAMBDA_ARG_TUP_T<camp::nil, Params...>());
225 return camp::make_tuple();
233 template<camp::
idx_t N>
236 return camp::tuple_cat_pair(
237 camp::get<param_tup_sz - N>(
param_tup).get_lambda_arg_tup(),
238 LAMBDA_ARG_TUP_V(camp::num<N - 1>()));
246 return LAMBDA_ARG_TUP_V(camp::num<
sizeof...(
Params)>());
250 camp::make_idx_seq_t<camp::tuple_size<lambda_arg_tuple_t>::value>;
252 template<
typename... Ts>
265 template<
typename EXEC_POL,
273 constexpr
bool has_reducers =
275 if constexpr (has_reducers)
277 FP::parampack_init(pol,
typename FP::params_seq(), f_params,
278 std::forward<Args>(
args)...);
282 template<
typename EXEC_POL,
291 constexpr
bool has_reducers =
293 if constexpr (has_reducers)
295 FP::parampack_combine(pol,
typename FP::params_seq(), f_params,
296 std::forward<Args>(
args)...);
300 template<
typename EXEC_POL,
308 constexpr
bool has_reducers =
310 if constexpr (has_reducers)
312 FP::parampack_resolve(pol,
typename FP::params_seq(), f_params,
313 std::forward<Args>(
args)...);
327 RAJA_INLINE
static auto get_empty_forall_param_pack()
329 static ForallParamPack<> p;
342 template<
typename Base,
typename... Ts>
347 template<
typename... Ts>
351 camp::decay<Ts>...>::value,
352 "Forall optional arguments do not derive ForallParamBase. "
353 "Please see Reducer, ReducerLoc and Name for examples.");
360 template<camp::idx_t... Seq,
typename TupleType>
363 return camp::forward_as_tuple(
364 camp::get<Seq>(std::forward<TupleType>(tuple))...);
367 template<
typename... Ts>
376 template<
typename... Args>
382 camp::forward_as_tuple(std::forward<Args>(
args)...));
395 template<
typename... Args>
399 camp::forward_as_tuple(std::forward<Args>(
args)...));
410 template<std::size_t name_idx,
412 std::enable_if_t<(name_idx <
sizeof...(Args))>* =
nullptr>
415 return std::string(camp::get<name_idx>(std::move(tuple_args)).name);
418 template<std::size_t name_idx,
420 std::enable_if_t<(name_idx >=
sizeof...(Args))>* =
nullptr>
424 return std::string();
427 template<
typename... Args, std::size_t... Idx>
429 camp::tuple<Args...>&& tuple_args,
433 constexpr std::size_t name_idx =
std::min(
438 return get_kernel_name_string<name_idx>(std::move(tuple_args));
441 template<
typename... Args>
445 camp::forward_as_tuple(std::forward<Args>(
args)...),
446 std::make_index_sequence<
sizeof...(Args)> {});
468 template<
class R,
class C,
class First,
class... Rest>
474 template<
class R,
class C,
class First,
class... Rest>
488 template<
typename... Ts>
491 return camp::list<camp::decay<typename std::remove_pointer<Ts>::type>...> {};
494 template<
typename... Ts>
497 return camp::list<typename std::add_lvalue_reference<Ts>::type...> {};
500 template<
typename... Ts>
503 return camp::list<Ts...> {};
507 template<
typename F,
typename... Args>
509 : std::is_constructible<
510 std::function<void(Args...)>,
511 std::reference_wrapper<typename std::remove_reference<F>::type>>
517 template<
class F,
class =
void>
535 template<
typename LAMBDA,
typename... EXPECTED_ARGS>
536 constexpr concepts::enable_if<concepts::negate<has_empty_op<LAMBDA>>>
540 template<
typename LAMBDA,
typename... EXPECTED_ARGS>
543 const camp::list<EXPECTED_ARGS...>&)
545 #if !defined(RAJA_ENABLE_HIP)
548 EXPECTED_ARGS...>::value,
549 "LAMBDA Not invocable w/ EXPECTED_ARGS. Ordering and types must match "
550 "between RAJA::expt::Reduce() and ValOp arguments.");
556 template<
typename Lambda,
typename ForallParams>
577 template<camp::
idx_t Idx,
typename FP>
579 -> decltype(*camp::get<Idx>(fpp.lambda_args()))
581 return (*camp::get<Idx>(fpp.lambda_args()));
584 CAMP_SUPPRESS_HD_WARN
585 template<
typename Fn, camp::idx_t... Sequence,
typename Params,
typename... Ts>
588 camp::idx_seq<Sequence...>,
591 return f(std::forward<Ts...>(extra...),
592 (get_lambda_args<Sequence>(params))...);
597 template<
typename Params,
typename Fn,
typename... Ts>
602 using FPType = camp::decay<Params>;
603 constexpr
bool has_reducers =
605 if constexpr (has_reducers)
608 camp::forward<Params>(params), camp::forward<Fn>(f),
609 typename camp::decay<Params>::lambda_arg_seq(),
610 camp::forward<Ts...>(extra)...);
614 return f(camp::forward<Ts...>(extra)...);
Header file for RAJA CombiningAdapter.
Header containing helper type traits for working with Reducers.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
#define RAJA_UNUSED_ARG(x)
Definition: macros.hpp:97
Args args
Definition: WorkRunner.hpp:212
CAMP_SUPPRESS_HD_WARN constexpr RAJA_HOST_DEVICE auto invoke_with_order(Params &&params, Fn &&f, camp::idx_seq< Sequence... >, Ts &&... extra)
Definition: forall.hpp:586
constexpr auto list_add_lvalue_ref(const camp::list< Ts... > &)
Definition: forall.hpp:495
camp::concepts::enable_if< std::is_same< EXEC_POL, RAJA::seq_exec > > param_init(EXEC_POL const &, RAJA::detail::Name &)
Definition: kernel_name.hpp:24
constexpr RAJA_HOST_DEVICE auto get_lambda_args(FP &fpp) -> decltype(*camp::get< Idx >(fpp.lambda_args()))
Definition: forall.hpp:578
lambda_traits< T >::arg_type * lambda_arg_helper(T)
std::is_same< bool_pack< bs..., true >, bool_pack< true, bs... > > all_true
Definition: forall.hpp:340
RAJA_HOST_DEVICE void combine_params(camp::tuple< Params... > &params_tuple)
Definition: forall.hpp:111
void init_params_helper(ParamTuple &params_tuple, const camp::idx_seq< Seq... > &, Args &&... args)
Definition: forall.hpp:67
constexpr RAJA_HOST_DEVICE auto filter_reducers(camp::tuple< Params... > &params)
Definition: forall.hpp:34
constexpr auto strip_last_elem(camp::tuple< Ts... > &&tuple)
Definition: forall.hpp:368
camp::concepts::enable_if< concepts::negate< is_instance_of_Reducer< camp::decay< T > > >, concepts::negate< std::is_same< T, RAJA::detail::Name > > > param_combine(EXEC_POL const &, T &, const T &)
Definition: forall.hpp:97
void resolve_params_helper(ParamTuple &params_tuple, const camp::idx_seq< Seq... > &, Args &&... args)
Definition: forall.hpp:44
RAJA_HOST_DEVICE void combine_params_helper(const camp::idx_seq< Seq... > &, ParamTuple &params_tuple)
Definition: forall.hpp:87
void void_t
Definition: forall.hpp:515
void resolve_params(camp::tuple< Params... > &params_tuple, Args &&... args)
Definition: forall.hpp:54
constexpr auto tuple_from_seq(const camp::idx_seq< Seq... > &, TupleType &&tuple)
Definition: forall.hpp:361
void init_params(camp::tuple< Params... > &params_tuple, Args &&... args)
Definition: forall.hpp:77
constexpr auto list_remove_pointer(const camp::list< Ts... > &)
Definition: forall.hpp:489
constexpr concepts::enable_if< concepts::negate< has_empty_op< LAMBDA > > > check_invocable(LAMBDA &&, const camp::list< EXPECTED_ARGS... > &)
Definition: forall.hpp:537
all_true< std::is_convertible< Ts, Base >::value... > check_types_derive_base
Definition: forall.hpp:344
constexpr auto tuple_to_list(const camp::tuple< Ts... > &)
Definition: forall.hpp:501
camp::concepts::enable_if< std::is_same< EXEC_POL, RAJA::seq_exec > > param_resolve(EXEC_POL const &, RAJA::detail::Name &)
Definition: kernel_name.hpp:40
std::string get_kernel_name(Args &&... args)
Definition: forall.hpp:442
constexpr void check_forall_optional_args(Lambda &&l, ForallParams &fpp)
Definition: forall.hpp:557
constexpr auto make_forall_param_pack_from_tuple(camp::tuple< Ts... > &&tuple)
Definition: forall.hpp:348
constexpr RAJA_HOST_DEVICE auto invoke_body(Params &&params, Fn &&f, Ts &&... extra)
Definition: forall.hpp:598
constexpr auto && get_lambda(Args &&... args)
Definition: forall.hpp:396
std::string get_kernel_name_helper(camp::tuple< Args... > &&tuple_args, std::index_sequence< Idx... > RAJA_UNUSED_ARG(i_seq))
Definition: forall.hpp:428
std::string get_kernel_name_string(camp::tuple< Args... > &&tuple_args)
Definition: forall.hpp:413
constexpr auto make_forall_param_pack(Args &&... args)
Definition: forall.hpp:377
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result min(Args... args)
Definition: foldl.hpp:161
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
camp::list< internal::LambdaArg< internal::lambda_arg_param_t, args >... > Params
Definition: Lambda.hpp:95
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155
Definition: ListSegment.hpp:416
Definition: kernel_name.hpp:21
Definition: forall.hpp:140
Base param_tup
Definition: forall.hpp:145
ForallParamPack()
Definition: forall.hpp:242
constexpr RAJA_HOST_DEVICE lambda_arg_tuple_t lambda_args()
Definition: forall.hpp:244
camp::tuple< Params... > Base
Definition: forall.hpp:144
static constexpr size_t param_tup_sz
Definition: forall.hpp:147
ForallParamPack(camp::tuple< Ts... > &&t)
Definition: forall.hpp:253
camp::make_idx_seq_t< camp::tuple_size< lambda_arg_tuple_t >::value > lambda_arg_seq
Definition: forall.hpp:250
camp::make_idx_seq_t< param_tup_sz > params_seq
Definition: forall.hpp:148
Definition: forall.hpp:264
static constexpr void parampack_resolve(EXEC_POL const &pol, ForallParamPack< Params... > &f_params, Args &&... args)
Definition: forall.hpp:304
static constexpr void parampack_init(EXEC_POL const &pol, ForallParamPack< Params... > &f_params, Args &&... args)
Definition: forall.hpp:269
static RAJA_HOST_DEVICE constexpr void parampack_combine(EXEC_POL const &pol, ForallParamPack< Params... > &f_params, Args &&... args)
Definition: forall.hpp:286
Definition: params_base.hpp:282
Definition: forall.hpp:338
Definition: forall.hpp:528
std::remove_pointer< decltype(lambda_arg_helper(&camp::decay< F >::operator()))>::type type
Definition: forall.hpp:530
Definition: forall.hpp:519
Definition: forall.hpp:512
First arg_type
Definition: forall.hpp:477
First arg_type
Definition: forall.hpp:471
Definition: forall.hpp:466
Definition: TypeTraits.hpp:67