22 #ifndef RAJA_policy_sycl_kernel_TileTCount_HPP
23 #define RAJA_policy_sycl_kernel_TileTCount_HPP
25 #include "RAJA/config.hpp"
27 #if defined(RAJA_ENABLE_SYCL)
30 #include <type_traits>
32 #include "camp/camp.hpp"
33 #include "camp/concepts.hpp"
34 #include "camp/tuple.hpp"
// Executor for TileTCount using a sequential (seq_exec) tile loop.
// It derives from the statement::Tile<ArgumentId, TPol, seq_exec, ...> executor.
// On top of the Tile behaviour (narrowing the segment to one tile at a time),
// it writes the current tile number into parameter ParamId before running the
// enclosed statements.
// NOTE(review): this listing omits lines from the original file, e.g. the
// ParamId/TPol template parameters, the exec() signature and the braces.
// The comments below describe only the visible lines.
52 template<
typename Data,
53 camp::idx_t ArgumentId,
56 typename... EnclosedStmts,
58 struct SyclStatementExecutor<
61 TileTCount<ArgumentId, ParamId, TPol, seq_exec, EnclosedStmts...>,
63 :
public SyclStatementExecutor<
65 statement::Tile<ArgumentId, TPol, seq_exec, EnclosedStmts...>,
69 using Base = SyclStatementExecutor<
71 statement::Tile<ArgumentId, TPol, seq_exec, EnclosedStmts...>,
// Reuse the base Tile executor's index-difference type and enclosed-statement list.
74 using typename Base::diff_t;
75 using typename Base::enclosed_stmts_t;
78 ::sycl::nd_item<3> item,
// Segment for ArgumentId, taken by reference: it is overwritten with each
// tile's slice below and restored at the end.
82 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
85 using segment_t = camp::decay<decltype(segment)>;
// Keep a copy of the full segment so each tile can be sliced from it and the
// segment can be restored afterwards.
86 segment_t orig_segment = segment;
// Tile width comes from the tile policy.
88 diff_t chunk_size = TPol::chunk_size;
// Total number of iterations in the untiled segment.
91 diff_t len = segment.end() - segment.begin();
// Walk every tile in order. i is the tile's starting offset and t is the tile
// number, which advances by one per tile.
94 for (diff_t i = 0, t = 0; i < len; i += chunk_size, ++t)
// Narrow the segment to this tile.
// NOTE(review): the last tile may run past len; presumably slice() clamps
// to the segment end — confirm.
98 segment = orig_segment.slice(i, chunk_size);
// Publish the tile number through parameter ParamId.
99 data.template assign_param<ParamId>(t);
// Run the enclosed statements on this tile, forwarding thread_active unchanged.
102 enclosed_stmts_t::exec(data, item, thread_active);
// Restore the untiled segment.
106 segment = orig_segment;
// Executor for TileTCount with a fixed tile size (RAJA::tile_fixed<chunk_size>)
// and the sycl_group_012_direct<BlockDim> policy.
// Each work-group handles the single tile whose number equals its group index
// along dimension BlockDim, and that tile number is written into ParamId.
// It derives from the matching statement::Tile executor.
// NOTE(review): this listing omits lines from the original file, e.g. the
// ParamId/BlockDim template parameters, the exec() signature and the braces.
// The comments below describe only the visible lines.
115 template<
typename Data,
116 camp::idx_t ArgumentId,
118 camp::idx_t chunk_size,
120 typename... EnclosedStmts,
122 struct SyclStatementExecutor<
124 statement::TileTCount<ArgumentId,
126 RAJA::tile_fixed<chunk_size>,
127 sycl_group_012_direct<BlockDim>,
130 :
public SyclStatementExecutor<
132 statement::Tile<ArgumentId,
133 RAJA::tile_fixed<chunk_size>,
134 sycl_group_012_direct<BlockDim>,
140 SyclStatementExecutor<Data,
141 statement::Tile<ArgumentId,
143 sycl_group_012_direct<BlockDim>,
// Reuse the base Tile executor's index-difference type and enclosed-statement list.
147 using typename Base::diff_t;
148 using typename Base::enclosed_stmts_t;
151 ::sycl::nd_item<3> item,
// Segment for ArgumentId, taken by reference: it is temporarily replaced by
// this group's tile and restored at the end.
155 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
157 using segment_t = camp::decay<decltype(segment)>;
// Total number of iterations in the untiled segment.
160 diff_t len = segment.end() - segment.begin();
// Tile number = this work-group's index along dimension BlockDim.
162 diff_t t = item.get_group(BlockDim);
// Starting offset of this group's tile.
163 diff_t i = t * chunk_size;
// NOTE(review): len is not used on any visible line. Original lines 164-169
// are omitted from this listing; presumably they guard the tile with i < len.
// Confirm against the real file.
// Keep a copy of the full segment so it can be restored after the tile runs.
170 segment_t orig_segment = segment;
// Narrow the segment to this group's tile.
173 segment = orig_segment.slice(i, chunk_size);
// Publish the tile number through parameter ParamId.
174 data.template assign_param<ParamId>(t);
// Run the enclosed statements on this tile, forwarding thread_active unchanged.
177 enclosed_stmts_t::exec(data, item, thread_active);
// Restore the untiled segment.
180 segment = orig_segment;
// Executor for TileTCount with a fixed tile size (RAJA::tile_fixed<chunk_size>)
// and the sycl_group_012_loop<BlockDim> policy.
// Each work-group starts at the tile numbered by its group index along
// BlockDim, then strides forward by the number of work-groups along that
// dimension. For each tile it handles, the tile number is written into ParamId.
// It derives from the matching statement::Tile executor.
// NOTE(review): this listing omits lines from the original file, e.g. the
// ParamId/BlockDim template parameters, the exec() signature and the braces.
// The comments below describe only the visible lines.
190 template<
typename Data,
191 camp::idx_t ArgumentId,
193 camp::idx_t chunk_size,
195 typename... EnclosedStmts,
197 struct SyclStatementExecutor<
199 statement::TileTCount<ArgumentId,
201 RAJA::tile_fixed<chunk_size>,
202 sycl_group_012_loop<BlockDim>,
205 :
public SyclStatementExecutor<
207 statement::Tile<ArgumentId,
208 RAJA::tile_fixed<chunk_size>,
209 sycl_group_012_loop<BlockDim>,
215 SyclStatementExecutor<Data,
216 statement::Tile<ArgumentId,
218 sycl_group_012_loop<BlockDim>,
// Reuse the base Tile executor's index-difference type and enclosed-statement list.
222 using typename Base::diff_t;
223 using typename Base::enclosed_stmts_t;
226 ::sycl::nd_item<3> item,
// Segment for ArgumentId, taken by reference: it is overwritten with each
// tile's slice below and restored at the end.
230 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
233 using segment_t = camp::decay<decltype(segment)>;
// Keep a copy of the full segment so each tile can be sliced from it.
234 segment_t orig_segment = segment;
// Total number of iterations in the untiled segment.
237 diff_t len = segment.end() - segment.begin();
// First tile number and offset for this work-group.
238 diff_t t_init = item.get_group(BlockDim);
239 diff_t i_init = t_init * chunk_size;
// Stride is the number of work-groups along BlockDim, in tiles (t_stride) and
// in iterations (i_stride).
240 diff_t t_stride = item.get_group_range(BlockDim);
241 diff_t i_stride = t_stride * chunk_size;
// Visit this group's tiles. i is the tile offset and t is the tile number;
// both advance together.
244 for (diff_t i = i_init, t = t_init; i < len; i += i_stride, t += t_stride)
// Narrow the segment to the current tile.
// NOTE(review): the last tile may run past len; presumably slice() clamps
// to the segment end — confirm.
248 segment = orig_segment.slice(i, chunk_size);
// Publish the tile number through parameter ParamId.
249 data.template assign_param<ParamId>(t);
// Run the enclosed statements on this tile, forwarding thread_active unchanged.
252 enclosed_stmts_t::exec(data, item, thread_active);
// Restore the untiled segment.
256 segment = orig_segment;
// Executor for TileTCount with a fixed tile size (RAJA::tile_fixed<chunk_size>)
// and the sycl_local_012_direct<ThreadDim> policy.
// Each work-item handles the single tile whose number equals its local id along
// dimension ThreadDim, and that tile number is written into ParamId.
// It derives from the matching statement::Tile executor.
// NOTE(review): this listing omits lines from the original file, e.g. the
// ParamId/ThreadDim template parameters, the exec() signature and the braces.
// The comments below describe only the visible lines.
265 template<
typename Data,
266 camp::idx_t ArgumentId,
268 camp::idx_t chunk_size,
270 typename... EnclosedStmts,
272 struct SyclStatementExecutor<
274 statement::TileTCount<ArgumentId,
276 RAJA::tile_fixed<chunk_size>,
277 sycl_local_012_direct<ThreadDim>,
280 :
public SyclStatementExecutor<
282 statement::Tile<ArgumentId,
283 RAJA::tile_fixed<chunk_size>,
284 sycl_local_012_direct<ThreadDim>,
290 SyclStatementExecutor<Data,
291 statement::Tile<ArgumentId,
293 sycl_local_012_direct<ThreadDim>,
// Reuse the base Tile executor's index-difference type and enclosed-statement list.
297 using typename Base::diff_t;
298 using typename Base::enclosed_stmts_t;
301 ::sycl::nd_item<3> item,
// Segment for ArgumentId, taken by reference: it is temporarily replaced by
// this work-item's tile and restored at the end.
305 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
308 using segment_t = camp::decay<decltype(segment)>;
// Keep a copy of the full segment so it can be restored after the tile runs.
309 segment_t orig_segment = segment;
// Total number of iterations in the untiled segment.
312 diff_t len = segment.end() - segment.begin();
// Tile number = this work-item's local id along ThreadDim.
314 diff_t t = item.get_local_id(ThreadDim);
// Starting offset of this work-item's tile.
315 diff_t i = t * chunk_size;
// A work-item has a tile only if the tile starts inside the segment.
319 bool have_work = i < len;
// Work-items without a tile get an empty slice and thread_active forced to
// false, but they still call the enclosed statements.
// NOTE(review): presumably this keeps every work-item on the same path, e.g.
// so all of them reach any group barrier in the enclosed statements — confirm.
322 diff_t slice_size = have_work ? chunk_size : 0;
323 segment = orig_segment.slice(i, slice_size);
// Publish the tile number through parameter ParamId. This also happens for
// work-items without a tile.
324 data.template assign_param<ParamId>(t);
// Run the enclosed statements. They are active only if the caller's
// thread_active is true and this work-item has a tile.
327 enclosed_stmts_t::exec(data, item, thread_active && have_work);
// Restore the untiled segment.
330 segment = orig_segment;
// Executor for TileTCount with a fixed tile size (RAJA::tile_fixed<chunk_size>)
// and the sycl_local_012_loop<ThreadDim> policy.
// Each work-item starts at the tile numbered by its local id along ThreadDim,
// then strides forward by the work-group size along that dimension. For each
// tile it handles, the tile number is written into ParamId.
// It derives from the matching statement::Tile executor.
// NOTE(review): this listing omits lines from the original file, e.g. the
// ParamId/ThreadDim template parameters, the exec() signature and the braces.
// The comments below describe only the visible lines.
339 template<
typename Data,
340 camp::idx_t ArgumentId,
342 camp::idx_t chunk_size,
344 typename... EnclosedStmts,
346 struct SyclStatementExecutor<
348 statement::TileTCount<ArgumentId,
350 RAJA::tile_fixed<chunk_size>,
351 sycl_local_012_loop<ThreadDim>,
354 :
public SyclStatementExecutor<
356 statement::Tile<ArgumentId,
357 RAJA::tile_fixed<chunk_size>,
358 sycl_local_012_loop<ThreadDim>,
364 SyclStatementExecutor<Data,
365 statement::Tile<ArgumentId,
367 sycl_local_012_loop<ThreadDim>,
// Reuse the base Tile executor's index-difference type and enclosed-statement list.
371 using typename Base::diff_t;
372 using typename Base::enclosed_stmts_t;
375 ::sycl::nd_item<3> item,
// Segment for ArgumentId, taken by reference: it is overwritten with each
// tile's slice below and restored at the end.
379 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
382 using segment_t = camp::decay<decltype(segment)>;
// Keep a copy of the full segment so each tile can be sliced from it.
383 segment_t orig_segment = segment;
// Total number of iterations in the untiled segment.
// NOTE(review): segment_length<> is defined elsewhere. The sibling executors
// compute segment.end() - segment.begin() directly; presumably the result is
// the same — confirm.
386 diff_t len = segment_length<ArgumentId>(data);
// First tile number and offset for this work-item.
388 diff_t t_init = item.get_local_id(ThreadDim);
389 diff_t i_init = t_init * chunk_size;
// Stride is the work-group size along ThreadDim, in tiles (t_stride) and in
// iterations (i_stride).
391 diff_t t_stride = item.get_local_range(ThreadDim);
392 diff_t i_stride = t_stride * chunk_size;
// The loop counter ii starts at 0 for every work-item and is bounded only by
// len and i_stride. So every work-item in the work-group runs the same number
// of iterations, including those whose own tile offset i is past len.
395 for (diff_t ii = 0, t = t_init; ii < len; ii += i_stride, t += t_stride)
// This work-item's tile offset in the current round.
397 diff_t i = ii + i_init;
// The work-item has a tile this round only if the tile starts inside the segment.
401 bool have_work = i < len;
// Without a tile, use an empty slice. The enclosed statements are still
// called (with thread_active forced false below).
// NOTE(review): presumably this keeps every work-item on the same path, e.g.
// so all of them reach any group barrier — confirm.
404 diff_t slice_size = have_work ? chunk_size : 0;
405 segment = orig_segment.slice(i, slice_size);
// Publish the tile number through parameter ParamId.
406 data.template assign_param<ParamId>(t);
// Run the enclosed statements. They are active only if the caller's
// thread_active is true and this work-item has a tile.
409 enclosed_stmts_t::exec(data, item, thread_active && have_work);
// Restore the untiled segment.
413 segment = orig_segment;
Header file for common RAJA internal macro definitions.
#define RAJA_DEVICE
Definition: macros.hpp:66
Definition: AlignedRangeIndexSetBuilders.cpp:35
Header file for tile wrapper and iterator.
Header file for loop kernel internals.
! tag for a tiling loop
Definition: Tile.hpp:72
Header file for RAJA type definitions.