21 #ifndef RAJA_policy_sycl_kernel_ForICount_HPP
22 #define RAJA_policy_sycl_kernel_ForICount_HPP
24 #include "RAJA/config.hpp"
// SyclStatementExecutor specialization for statement::ForICount combined with
// the RAJA::sycl_local_012_direct<ThreadDim> policy.
// NOTE(review): the embedded line numbers skip values, so some lines of this
// definition (e.g. where ParamId and ThreadDim are declared) are not visible
// in this listing.
41 template<
typename Data,
42 camp::idx_t ArgumentId,
45 typename... EnclosedStmts,
47 struct SyclStatementExecutor<
49 statement::ForICount<ArgumentId,
51 RAJA::sycl_local_012_direct<ThreadDim>,
54 :
// Derives from the executor of the matching statement::For; diff_t and
// enclosed_stmts_t are taken from that base below.
public SyclStatementExecutor<
56 statement::For<ArgumentId,
57 RAJA::sycl_local_012_direct<ThreadDim>,
62 using Base = SyclStatementExecutor<
65 RAJA::sycl_local_012_direct<ThreadDim>,
69 using typename Base::diff_t;
70 using typename Base::enclosed_stmts_t;
73 ::sycl::nd_item<3> item,
// Length of the segment selected by ArgumentId.
76 diff_t len = segment_length<ArgumentId>(data);
// The work-item's local id along ThreadDim serves as the iteration index.
77 auto i = item.get_local_id(ThreadDim);
// Record i through assign_offset for ArgumentId and through assign_param for
// ParamId.
80 data.template assign_offset<ArgumentId>(i);
81 data.template assign_param<ParamId>(i);
// Run the enclosed statements; the active flag is cleared when i is not below
// len.
84 enclosed_stmts_t::exec(data, item, thread_active && (i < len));
// SyclStatementExecutor specialization for statement::ForICount combined with
// the RAJA::sycl_local_masked_direct<Mask> policy.
// NOTE(review): the embedded line numbers skip values, so some lines of this
// definition (e.g. where ParamId, Mask and mask_t are declared) are not
// visible in this listing.
92 template<
typename Data,
93 camp::idx_t ArgumentId,
96 typename... EnclosedStmts,
98 struct SyclStatementExecutor<
100 statement::ForICount<ArgumentId,
102 RAJA::sycl_local_masked_direct<Mask>,
105 :
// Derives from the executor of the matching statement::For.
public SyclStatementExecutor<
107 statement::For<ArgumentId,
108 RAJA::sycl_local_masked_direct<Mask>,
114 SyclStatementExecutor<Data,
116 RAJA::sycl_local_masked_direct<Mask>,
120 using typename Base::diff_t;
// Presumably the right-hand side of the enclosed_stmts_t alias; the line that
// names the alias is not shown here — TODO confirm.
128 SyclStatementListExecutor<Data, stmt_list_t, NewTypes>;
133 ::sycl::nd_item<3> item,
// Length of the segment selected by ArgumentId.
136 diff_t len = segment_length<ArgumentId>(data);
// Local id in dimension 0, passed through mask_t::maskValue to obtain the
// iteration index.
137 auto i0 = item.get_local_id(0);
138 diff_t i = mask_t::maskValue(i0);
// Record i through assign_offset for ArgumentId and through assign_param for
// ParamId.
141 data.template assign_offset<ArgumentId>(i);
142 data.template assign_param<ParamId>(i);
// Run the enclosed statements; the active flag is cleared when i is not below
// len.
145 enclosed_stmts_t::exec(data, item, thread_active && (i < len));
// SyclStatementExecutor specialization for statement::ForICount combined with
// the RAJA::sycl_local_masked_loop<Mask> policy.
// NOTE(review): the embedded line numbers skip values, so some lines of this
// definition (e.g. where ParamId, Mask and mask_t are declared) are not
// visible in this listing.
153 template<
typename Data,
154 camp::idx_t ArgumentId,
157 typename... EnclosedStmts,
159 struct SyclStatementExecutor<
161 statement::ForICount<ArgumentId,
163 RAJA::sycl_local_masked_loop<Mask>,
166 :
// Derives from the executor of the matching statement::For.
public SyclStatementExecutor<
168 statement::For<ArgumentId,
169 RAJA::sycl_local_masked_loop<Mask>,
175 SyclStatementExecutor<Data,
177 RAJA::sycl_local_masked_loop<Mask>,
181 using typename Base::diff_t;
// Presumably the right-hand side of the enclosed_stmts_t alias; the line that
// names the alias is not shown here — TODO confirm.
189 SyclStatementListExecutor<Data, stmt_list_t, NewTypes>;
194 ::sycl::nd_item<3> item,
// Length of the segment selected by ArgumentId.
198 diff_t len = segment_length<ArgumentId>(data);
// Starting index: local id in dimension 0 passed through mask_t::maskValue.
199 auto i0 = item.get_local_id(0);
200 diff_t i_init = mask_t::maskValue(i0);
// Step between successive indices, taken from mask_t::max_masked_size.
201 diff_t i_stride = (diff_t)mask_t::max_masked_size;
// The loop condition tests ii rather than i, so the trip count does not depend
// on i_init; iterations whose i falls outside the segment are reported through
// have_work instead of being skipped.
204 for (diff_t ii = 0; ii < len; ii += i_stride)
206 diff_t i = ii + i_init;
210 bool have_work = i < len;
// Record i through assign_offset for ArgumentId and through assign_param for
// ParamId.
213 data.template assign_offset<ArgumentId>(i);
214 data.template assign_param<ParamId>(i);
// Run the enclosed statements, active only when the incoming flag is set and
// this iteration has work.
217 enclosed_stmts_t::exec(data, item, thread_active && have_work);
// SyclStatementExecutor specialization for statement::ForICount combined with
// the RAJA::sycl_local_012_loop<ThreadDim> policy.
// NOTE(review): the embedded line numbers skip values, so some lines of this
// definition (e.g. where ParamId and ThreadDim are declared) are not visible
// in this listing.
229 template<
typename Data,
230 camp::idx_t ArgumentId,
233 typename... EnclosedStmts,
235 struct SyclStatementExecutor<
237 statement::ForICount<ArgumentId,
239 RAJA::sycl_local_012_loop<ThreadDim>,
242 :
// Derives from the executor of the matching statement::For; diff_t and
// enclosed_stmts_t are taken from that base below.
public SyclStatementExecutor<
244 statement::For<ArgumentId,
245 RAJA::sycl_local_012_loop<ThreadDim>,
251 SyclStatementExecutor<Data,
253 RAJA::sycl_local_012_loop<ThreadDim>,
257 using typename Base::diff_t;
258 using typename Base::enclosed_stmts_t;
261 ::sycl::nd_item<3> item,
// Length of the segment selected by ArgumentId.
265 diff_t len = segment_length<ArgumentId>(data);
// Starting index is the local id along ThreadDim; the step is the local range
// along the same dimension.
266 auto i_init = item.get_local_id(ThreadDim);
267 auto i_stride = item.get_local_range(ThreadDim);
// The loop condition tests ii rather than i, so the trip count does not depend
// on i_init; iterations whose i falls outside the segment are reported through
// have_work instead of being skipped.
// NOTE(review): i_init and i_stride are deduced with auto while ii and len are
// diff_t — confirm the mixed arithmetic/comparison types are intended.
270 for (diff_t ii = 0; ii < len; ii += i_stride)
272 diff_t i = ii + i_init;
276 bool have_work = i < len;
// Record i through assign_offset for ArgumentId and through assign_param for
// ParamId.
279 data.template assign_offset<ArgumentId>(i);
280 data.template assign_param<ParamId>(i);
// Run the enclosed statements, active only when the incoming flag is set and
// this iteration has work.
283 enclosed_stmts_t::exec(data, item, thread_active && have_work);
// SyclStatementExecutor specialization for statement::ForICount combined with
// the RAJA::sycl_group_012_direct<BlockDim> policy.
// NOTE(review): the embedded line numbers skip values, so some lines of this
// definition (e.g. where ParamId and BlockDim are declared) are not visible
// in this listing.
294 template<
typename Data,
295 camp::idx_t ArgumentId,
298 typename... EnclosedStmts,
300 struct SyclStatementExecutor<
302 statement::ForICount<ArgumentId,
304 RAJA::sycl_group_012_direct<BlockDim>,
307 :
// Derives from the executor of the matching statement::For; diff_t and
// enclosed_stmts_t are taken from that base below.
public SyclStatementExecutor<
309 statement::For<ArgumentId,
310 RAJA::sycl_group_012_direct<BlockDim>,
315 using Base = SyclStatementExecutor<
318 RAJA::sycl_group_012_direct<BlockDim>,
322 using typename Base::diff_t;
323 using typename Base::enclosed_stmts_t;
326 ::sycl::nd_item<3> item,
// Length of the segment selected by ArgumentId.
330 diff_t len = segment_length<ArgumentId>(data);
// The work-group id along BlockDim serves as the iteration index.
331 auto i = item.get_group(BlockDim);
// NOTE(review): len is not used in the lines visible here; the numbering gap
// before the next line may hide a bounds guard on i — confirm against the
// full file.
// Record i through assign_offset for ArgumentId and through assign_param for
// ParamId.
337 data.template assign_offset<ArgumentId>(i);
338 data.template assign_param<ParamId>(i);
// Run the enclosed statements with the incoming active flag passed through
// unchanged.
341 enclosed_stmts_t::exec(data, item, thread_active);
// SyclStatementExecutor specialization for statement::ForICount combined with
// the RAJA::sycl_group_012_loop<BlockDim> policy.
// NOTE(review): the embedded line numbers skip values, so some lines of this
// definition (e.g. where ParamId and BlockDim are declared) are not visible
// in this listing.
353 template<
typename Data,
354 camp::idx_t ArgumentId,
357 typename... EnclosedStmts,
359 struct SyclStatementExecutor<
361 statement::ForICount<ArgumentId,
363 RAJA::sycl_group_012_loop<BlockDim>,
366 :
// Derives from the executor of the matching statement::For; diff_t and
// enclosed_stmts_t are taken from that base below.
public SyclStatementExecutor<
368 statement::For<ArgumentId,
369 RAJA::sycl_group_012_loop<BlockDim>,
375 SyclStatementExecutor<Data,
377 RAJA::sycl_group_012_loop<BlockDim>,
381 using typename Base::diff_t;
382 using typename Base::enclosed_stmts_t;
385 ::sycl::nd_item<3> item,
// Length of the segment selected by ArgumentId.
389 diff_t len = segment_length<ArgumentId>(data);
// Starting index is the work-group id along BlockDim; the step is the number
// of work-groups along the same dimension.
390 auto i_init = item.get_group(BlockDim);
391 auto i_stride = item.get_group_range(BlockDim);
// Unlike the local-id loop variants in this file, the loop bound is applied to
// i directly, so no separate have_work flag is computed.
394 for (diff_t i = i_init; i < len; i += i_stride)
// Record i through assign_offset for ArgumentId and through assign_param for
// ParamId.
398 data.template assign_offset<ArgumentId>(i);
399 data.template assign_param<ParamId>(i);
// Run the enclosed statements with the incoming active flag passed through
// unchanged.
402 enclosed_stmts_t::exec(data, item, thread_active);
// SyclStatementExecutor specialization for statement::ForICount combined with
// the seq_exec policy.
// NOTE(review): the embedded line numbers skip values, so some lines of this
// definition (e.g. where ParamId, diff_t and enclosed_stmts_t come into
// scope) are not visible in this listing.
414 template<
typename Data,
415 camp::idx_t ArgumentId,
417 typename... EnclosedStmts,
419 struct SyclStatementExecutor<
421 statement::ForICount<ArgumentId, ParamId, seq_exec, EnclosedStmts...>,
423 :
// Derives from the executor of the matching seq_exec statement::For.
public SyclStatementExecutor<
425 statement::For<ArgumentId, seq_exec, EnclosedStmts...>,
429 using Base = SyclStatementExecutor<
438 ::sycl::nd_item<3> item,
// Length of the segment selected by ArgumentId.
441 diff_t len = segment_length<ArgumentId>(data);
// Visit every index from 0 up to (but not including) len; no nd_item id is
// consulted when choosing the index.
443 for (
diff_t i = 0; i < len; ++i)
// Record i through assign_offset for ArgumentId and through assign_param for
// ParamId.
446 data.template assign_offset<ArgumentId>(i);
447 data.template assign_param<ParamId>(i);
// Run the enclosed statements with the incoming active flag passed through
// unchanged.
450 enclosed_stmts_t::exec(data, item, thread_active);
#define RAJA_DEVICE
Definition: macros.hpp:66
setSegmentType< Types, Segment, camp::at_v< typename camp::decay< Data >::index_types_t, Segment > > setSegmentTypeFromData
Definition: LoopTypes.hpp:95
camp::list< Stmts... > StatementList
Definition: StatementList.hpp:41
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA header file containing constructs used to run kernel traversals on GPU with SYCL.
static RAJA_DEVICE void exec(Data &data, ::sycl::nd_item< 3 > item, bool thread_active)
Definition: ForICount.hpp:72
SyclStatementExecutor< Data, statement::For< ArgumentId, RAJA::sycl_local_012_direct< ThreadDim >, EnclosedStmts... >, Types > Base
Definition: ForICount.hpp:67
static RAJA_DEVICE void exec(Data &data, ::sycl::nd_item< 3 > item, bool thread_active)
Definition: ForICount.hpp:325
SyclStatementExecutor< Data, statement::For< ArgumentId, RAJA::sycl_group_012_direct< BlockDim >, EnclosedStmts... >, Types > Base
Definition: ForICount.hpp:320
static RAJA_DEVICE void exec(Data &data, ::sycl::nd_item< 3 > item, bool thread_active)
Definition: ForICount.hpp:260
SyclStatementExecutor< Data, statement::For< ArgumentId, RAJA::sycl_local_012_loop< ThreadDim >, EnclosedStmts... >, Types > Base
Definition: ForICount.hpp:255
SyclStatementExecutor< Data, statement::For< ArgumentId, RAJA::sycl_local_masked_loop< Mask >, EnclosedStmts... >, Types > Base
Definition: ForICount.hpp:179
static RAJA_DEVICE void exec(Data &data, ::sycl::nd_item< 3 > item, bool thread_active)
Definition: ForICount.hpp:193
StatementList< EnclosedStmts... > stmt_list_t
Definition: ForICount.hpp:183
SyclStatementListExecutor< Data, stmt_list_t, NewTypes > enclosed_stmts_t
Definition: ForICount.hpp:189
setSegmentTypeFromData< Types, ArgumentId, Data > NewTypes
Definition: ForICount.hpp:186
Mask mask_t
Definition: ForICount.hpp:191
static RAJA_DEVICE void exec(Data &data, ::sycl::nd_item< 3 > item, bool thread_active)
Definition: ForICount.hpp:384
SyclStatementExecutor< Data, statement::For< ArgumentId, RAJA::sycl_group_012_loop< BlockDim >, EnclosedStmts... >, Types > Base
Definition: ForICount.hpp:379
Mask mask_t
Definition: ForICount.hpp:130
static RAJA_DEVICE void exec(Data &data, ::sycl::nd_item< 3 > item, bool thread_active)
Definition: ForICount.hpp:132
SyclStatementListExecutor< Data, stmt_list_t, NewTypes > enclosed_stmts_t
Definition: ForICount.hpp:128
SyclStatementExecutor< Data, statement::For< ArgumentId, RAJA::sycl_local_masked_direct< Mask >, EnclosedStmts... >, Types > Base
Definition: ForICount.hpp:118
StatementList< EnclosedStmts... > stmt_list_t
Definition: ForICount.hpp:122
setSegmentTypeFromData< Types, ArgumentId, Data > NewTypes
Definition: ForICount.hpp:125
static RAJA_DEVICE void exec(Data &data, ::sycl::nd_item< 3 > item, bool thread_active)
Definition: ForICount.hpp:437
SyclStatementListExecutor< Data, stmt_list_t, NewTypes > enclosed_stmts_t
Definition: For.hpp:471
segment_diff_type< ArgumentId, Data > diff_t
Definition: For.hpp:473
Definition: policy.hpp:78