22 #ifndef RAJA_policy_hip_kernel_TileTCount_HPP
23 #define RAJA_policy_hip_kernel_TileTCount_HPP
25 #include "RAJA/config.hpp"
27 #if defined(RAJA_ENABLE_HIP)
30 #include <type_traits>
32 #include "camp/camp.hpp"
33 #include "camp/concepts.hpp"
34 #include "camp/tuple.hpp"
53 template<
typename Data,
54 camp::idx_t ArgumentId,
56 camp::idx_t chunk_size,
59 typename... EnclosedStmts,
61 struct HipStatementExecutor<
63 statement::TileTCount<
66 RAJA::tile_fixed<chunk_size>,
68 hip_indexer<iteration_mapping::DirectUnchecked, sync, IndexMapper>,
71 :
public HipStatementExecutor<
75 RAJA::tile_fixed<chunk_size>,
76 RAJA::policy::hip::hip_indexer<iteration_mapping::DirectUnchecked,
83 using Base = HipStatementExecutor<
88 RAJA::policy::hip::hip_indexer<iteration_mapping::DirectUnchecked,
94 using typename Base::diff_t;
95 using typename Base::enclosed_stmts_t;
97 static inline RAJA_DEVICE void exec(Data& data,
bool thread_active)
100 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
102 using segment_t = camp::decay<decltype(segment)>;
105 const diff_t t = IndexMapper::template index<diff_t>();
106 const diff_t i = t *
static_cast<diff_t
>(chunk_size);
109 segment_t orig_segment = segment;
112 segment = orig_segment.slice(i,
static_cast<diff_t
>(chunk_size));
113 data.template assign_param<ParamId>(t);
116 enclosed_stmts_t::exec(data, thread_active);
119 segment = orig_segment;
129 template<
typename Data,
130 camp::idx_t ArgumentId,
132 camp::idx_t chunk_size,
133 typename IndexMapper,
135 typename... EnclosedStmts,
137 struct HipStatementExecutor<
139 statement::TileTCount<
142 RAJA::tile_fixed<chunk_size>,
144 hip_indexer<iteration_mapping::Direct, sync, IndexMapper>,
147 :
public HipStatementExecutor<
151 RAJA::tile_fixed<chunk_size>,
153 hip_indexer<iteration_mapping::Direct, sync, IndexMapper>,
158 using Base = HipStatementExecutor<
160 statement::Tile<ArgumentId,
162 RAJA::policy::hip::hip_indexer<iteration_mapping::Direct,
168 using typename Base::diff_t;
169 using typename Base::enclosed_stmts_t;
171 static inline RAJA_DEVICE void exec(Data& data,
bool thread_active)
174 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
176 using segment_t = camp::decay<decltype(segment)>;
179 const diff_t len = segment.end() - segment.begin();
180 const diff_t t = IndexMapper::template index<diff_t>();
181 const diff_t i = t *
static_cast<diff_t
>(chunk_size);
185 const bool have_work = (i < len);
188 segment_t orig_segment = segment;
191 segment = orig_segment.slice(i,
static_cast<diff_t
>(chunk_size));
192 data.template assign_param<ParamId>(t);
195 enclosed_stmts_t::exec(data, thread_active && have_work);
198 segment = orig_segment;
208 template<
typename Data,
209 camp::idx_t ArgumentId,
211 camp::idx_t chunk_size,
212 typename IndexMapper,
213 typename... EnclosedStmts,
215 struct HipStatementExecutor<
217 statement::TileTCount<
220 RAJA::tile_fixed<chunk_size>,
221 RAJA::policy::hip::hip_indexer<
222 iteration_mapping::StridedLoop<named_usage::unspecified>,
223 kernel_sync_requirement::sync,
227 :
public HipStatementExecutor<
231 RAJA::tile_fixed<chunk_size>,
232 RAJA::policy::hip::hip_indexer<
233 iteration_mapping::StridedLoop<named_usage::unspecified>,
234 kernel_sync_requirement::sync,
240 using Base = HipStatementExecutor<
245 RAJA::policy::hip::hip_indexer<
246 iteration_mapping::StridedLoop<named_usage::unspecified>,
252 using typename Base::diff_t;
253 using typename Base::enclosed_stmts_t;
255 static inline RAJA_DEVICE void exec(Data& data,
bool thread_active)
258 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
261 using segment_t = camp::decay<decltype(segment)>;
262 segment_t orig_segment = segment;
265 const diff_t len = segment.end() - segment.begin();
266 const diff_t t_init = IndexMapper::template index<diff_t>();
267 const diff_t i_init = t_init *
static_cast<diff_t
>(chunk_size);
268 const diff_t t_stride = IndexMapper::template size<diff_t>();
269 const diff_t i_stride = t_stride *
static_cast<diff_t
>(chunk_size);
273 for (diff_t ii = 0, t = t_init; ii < len; ii += i_stride, t += t_stride)
275 const diff_t i = ii + i_init;
279 const bool have_work = (i < len);
282 segment = orig_segment.slice(i,
static_cast<diff_t
>(chunk_size));
283 data.template assign_param<ParamId>(t);
286 enclosed_stmts_t::exec(data, thread_active && have_work);
290 segment = orig_segment;
300 template<
typename Data,
301 camp::idx_t ArgumentId,
303 camp::idx_t chunk_size,
304 typename IndexMapper,
305 typename... EnclosedStmts,
307 struct HipStatementExecutor<
309 statement::TileTCount<
312 RAJA::tile_fixed<chunk_size>,
313 RAJA::policy::hip::hip_indexer<
314 iteration_mapping::StridedLoop<named_usage::unspecified>,
315 kernel_sync_requirement::none,
319 :
public HipStatementExecutor<
323 RAJA::tile_fixed<chunk_size>,
324 RAJA::policy::hip::hip_indexer<
325 iteration_mapping::StridedLoop<named_usage::unspecified>,
326 kernel_sync_requirement::none,
332 using Base = HipStatementExecutor<
337 RAJA::policy::hip::hip_indexer<
338 iteration_mapping::StridedLoop<named_usage::unspecified>,
344 using typename Base::diff_t;
345 using typename Base::enclosed_stmts_t;
347 static inline RAJA_DEVICE void exec(Data& data,
bool thread_active)
350 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
353 using segment_t = camp::decay<decltype(segment)>;
354 segment_t orig_segment = segment;
357 const diff_t len = segment.end() - segment.begin();
358 const diff_t t_init = IndexMapper::template index<diff_t>();
359 const diff_t i_init = t_init *
static_cast<diff_t
>(chunk_size);
360 const diff_t t_stride = IndexMapper::template size<diff_t>();
361 const diff_t i_stride = t_stride *
static_cast<diff_t
>(chunk_size);
365 for (diff_t i = i_init, t = t_init; i < len; i += i_stride, t += t_stride)
369 segment = orig_segment.slice(i,
static_cast<diff_t
>(chunk_size));
370 data.template assign_param<ParamId>(t);
373 enclosed_stmts_t::exec(data, thread_active);
377 segment = orig_segment;
386 template<
typename Data,
387 camp::idx_t ArgumentId,
390 typename... EnclosedStmts,
392 struct HipStatementExecutor<
395 TileTCount<ArgumentId, ParamId, TPol, seq_exec, EnclosedStmts...>,
397 : HipStatementExecutor<
399 statement::TileTCount<
403 RAJA::policy::hip::hip_indexer<
404 iteration_mapping::StridedLoop<named_usage::unspecified>,
405 kernel_sync_requirement::none,
406 hip::IndexGlobal<named_dim::x,
407 named_usage::ignored,
408 named_usage::ignored>>,
Header file for common RAJA internal macro definitions.
#define RAJA_DEVICE
Definition: macros.hpp:66
Definition: AlignedRangeIndexSetBuilders.cpp:35
kernel_sync_requirement
Definition: types.hpp:63
Header file for tile wrapper and iterator.
Header file for loop kernel internals.
! tag for a tiling loop
Definition: Tile.hpp:72
Header file for RAJA type definitions.