21 #ifndef RAJA_policy_cuda_kernel_InitLocalMem_HPP
22 #define RAJA_policy_cuda_kernel_InitLocalMem_HPP
24 #include "RAJA/config.hpp"
27 #include <type_traits>
38 struct cuda_thread_mem;
39 struct cuda_shared_mem;
// Partial specialization of CudaStatementExecutor for
// statement::InitLocalMem with the RAJA::cuda_shared_mem policy.
// NOTE(review): this listing keeps the original file's line numbers as a
// leading token on most lines, and the gaps in that numbering show that
// several original lines (function signatures, braces, the remaining
// template arguments) are not present in this view.
45 template<
typename Data,
46 camp::idx_t... Indices,
47 typename... EnclosedStmts,
49 struct CudaStatementExecutor<Data,
50 statement::InitLocalMem<RAJA::cuda_shared_mem,
51 camp::idx_seq<Indices...>,
// initMem, single-index overload (end of the recursion).
//   varType : value_type of the Pos-th element of Data's param_tuple_t.
//   NumElem : layout_type::s_size of that same element.
// Declares a __shared__ array of NumElem varType elements, passes its address
// to the Pos-th entry of data.param_tuple through set_data(), and then runs
// the enclosed statements with the caller's thread_active flag.
60 template<camp::
idx_t Pos>
63 using varType =
typename camp::tuple_element_t<
64 Pos,
typename camp::decay<Data>::param_tuple_t>::value_type;
65 const camp::idx_t NumElem = camp::tuple_element_t<
66 Pos,
typename camp::decay<Data>::param_tuple_t>::layout_type::s_size;
68 __shared__ varType Array[NumElem];
69 camp::get<Pos>(data.param_tuple).set_data(&Array[0]);
71 enclosed_stmts_t::exec(data, thread_active);
// initMem, multi-index overload.
// Performs the same __shared__ allocation and set_data() call for index Pos,
// then recurses on the remaining indices (other0, others...). It does not run
// the enclosed statements itself; that happens in the single-index overload
// once only one index is left.
76 template<camp::idx_t Pos, camp::idx_t other0, camp::idx_t... others>
79 using varType =
typename camp::tuple_element_t<
80 Pos,
typename camp::decay<Data>::param_tuple_t>::value_type;
81 const camp::idx_t NumElem = camp::tuple_element_t<
82 Pos,
typename camp::decay<Data>::param_tuple_t>::layout_type::s_size;
84 __shared__ varType Array[NumElem];
85 camp::get<Pos>(data.param_tuple).set_data(&Array[0]);
86 initMem<other0, others...>(data, thread_active);
// setPtrToNull, single-index overload: resets the data pointer of the Pos-th
// entry of data.param_tuple to nullptr.
90 template<camp::
idx_t Pos>
94 camp::get<Pos>(data.param_tuple).set_data(
nullptr);
// setPtrToNull, multi-index overload: resets entry Pos to nullptr, then
// recurses on the remaining indices.
98 template<camp::idx_t Pos, camp::idx_t other0, camp::idx_t... others>
102 camp::get<Pos>(data.param_tuple).set_data(
nullptr);
103 setPtrToNull<other0, others...>(data);
// Body of exec (signature not visible in this view): initMem sets up storage
// for every index in Indices... and runs the enclosed statements; once it
// returns, setPtrToNull clears the pointers that initMem installed.
110 initMem<Indices...>(data, thread_active);
113 setPtrToNull<Indices...>(data);
// Body of calculateDimensions (signature not visible in this view): forwards
// to the enclosed statement list and returns its result unchanged.
118 return enclosed_stmts_t::calculateDimensions(data);
// Partial specialization of CudaStatementExecutor for
// statement::InitLocalMem with the RAJA::cuda_thread_mem policy.
// NOTE(review): this listing keeps the original file's line numbers as a
// leading token on most lines, and the gaps in that numbering show that
// several original lines (function signatures, braces, the remaining
// template arguments) are not present in this view.
123 template<
typename Data,
124 camp::idx_t... Indices,
125 typename... EnclosedStmts,
127 struct CudaStatementExecutor<Data,
128 statement::InitLocalMem<RAJA::cuda_thread_mem,
129 camp::idx_seq<Indices...>,
// initMem, single-index overload (end of the recursion).
//   varType : value_type of the Pos-th element of Data's param_tuple_t.
//   NumElem : layout_type::s_size of that same element.
// Declares a function-local (not __shared__) array of NumElem varType
// elements, passes its address to the Pos-th entry of data.param_tuple
// through set_data(), and then runs the enclosed statements. The enclosed
// statements are executed from inside this function, i.e. while Array is
// still in scope.
138 template<camp::
idx_t Pos>
141 using varType =
typename camp::tuple_element_t<
142 Pos,
typename camp::decay<Data>::param_tuple_t>::value_type;
143 const camp::idx_t NumElem = camp::tuple_element_t<
144 Pos,
typename camp::decay<Data>::param_tuple_t>::layout_type::s_size;
146 varType Array[NumElem];
147 camp::get<Pos>(data.param_tuple).set_data(&Array[0]);
149 enclosed_stmts_t::exec(data, thread_active);
// initMem, multi-index overload.
// Declares the function-local array for index Pos and registers it with
// set_data(), then recurses on the remaining indices (other0, others...).
// The recursive call is made from inside this function, so the Array declared
// here is still in scope while the remaining indices are set up and the
// enclosed statements run.
154 template<camp::idx_t Pos, camp::idx_t other0, camp::idx_t... others>
157 using varType =
typename camp::tuple_element_t<
158 Pos,
typename camp::decay<Data>::param_tuple_t>::value_type;
159 const camp::idx_t NumElem = camp::tuple_element_t<
160 Pos,
typename camp::decay<Data>::param_tuple_t>::layout_type::s_size;
162 varType Array[NumElem];
163 camp::get<Pos>(data.param_tuple).set_data(&Array[0]);
164 initMem<other0, others...>(data, thread_active);
// setPtrToNull, single-index overload: resets the data pointer of the Pos-th
// entry of data.param_tuple to nullptr.
168 template<camp::
idx_t Pos>
172 camp::get<Pos>(data.param_tuple).set_data(
nullptr);
// setPtrToNull, multi-index overload: resets entry Pos to nullptr, then
// recurses on the remaining indices.
176 template<camp::idx_t Pos, camp::idx_t other0, camp::idx_t... others>
180 camp::get<Pos>(data.param_tuple).set_data(
nullptr);
181 setPtrToNull<other0, others...>(data);
// Body of exec (signature not visible in this view): initMem sets up storage
// for every index in Indices... and runs the enclosed statements. By the time
// it returns, the local arrays it declared are out of scope, and setPtrToNull
// then clears the pointers that referred to them.
188 initMem<Indices...>(data, thread_active);
191 setPtrToNull<Indices...>(data);
// Body of calculateDimensions (signature not visible in this view): forwards
// to the enclosed statement list and returns its result unchanged.
196 return enclosed_stmts_t::calculateDimensions(data);
Header file for common RAJA internal macro definitions.
#define RAJA_DEVICE
Definition: macros.hpp:66
camp::list< Stmts... > StatementList
Definition: StatementList.hpp:41
Definition: AlignedRangeIndexSetBuilders.cpp:35
Header file for shared memory window.
RAJA header file containing constructs used to run kernel traversals on GPU with CUDA.
CudaStatementListExecutor< Data, stmt_list_t, Types > enclosed_stmts_t
Definition: InitLocalMem.hpp:57
StatementList< EnclosedStmts... > stmt_list_t
Definition: InitLocalMem.hpp:56
static RAJA_DEVICE void exec(Data &data, bool thread_active)
Definition: InitLocalMem.hpp:106
static RAJA_DEVICE void setPtrToNull(Data &data)
Definition: InitLocalMem.hpp:91
static RAJA_DEVICE void initMem(Data &data, bool thread_active)
Definition: InitLocalMem.hpp:77
static RAJA_DEVICE void initMem(Data &data, bool thread_active)
Definition: InitLocalMem.hpp:61
static LaunchDims calculateDimensions(Data const &data)
Definition: InitLocalMem.hpp:116
static RAJA_DEVICE void setPtrToNull(Data &data)
Definition: InitLocalMem.hpp:99
static LaunchDims calculateDimensions(Data const &data)
Definition: InitLocalMem.hpp:194
static RAJA_DEVICE void exec(Data &data, bool thread_active)
Definition: InitLocalMem.hpp:184
CudaStatementListExecutor< Data, stmt_list_t, Types > enclosed_stmts_t
Definition: InitLocalMem.hpp:135
static RAJA_DEVICE void setPtrToNull(Data &data)
Definition: InitLocalMem.hpp:169
static RAJA_DEVICE void initMem(Data &data, bool thread_active)
Definition: InitLocalMem.hpp:155
static RAJA_DEVICE void initMem(Data &data, bool thread_active)
Definition: InitLocalMem.hpp:139
static RAJA_DEVICE void setPtrToNull(Data &data)
Definition: InitLocalMem.hpp:177
StatementList< EnclosedStmts... > stmt_list_t
Definition: InitLocalMem.hpp:134
Header file for RAJA type definitions.