22 #ifndef RAJA_policy_sycl_kernel_Tile_HPP
23 #define RAJA_policy_sycl_kernel_Tile_HPP
25 #include "RAJA/config.hpp"
27 #if defined(RAJA_ENABLE_SYCL)
30 #include <type_traits>
32 #include "camp/camp.hpp"
33 #include "camp/concepts.hpp"
34 #include "camp/tuple.hpp"
// Executor for statement::Tile with a seq_exec tiling policy: each work-item
// walks every tile of segment ArgumentId in order, TPol::chunk_size elements
// at a time, running the enclosed statements once per tile.
// NOTE(review): this listing is missing lines (remaining template parameters,
// function headers, braces); comments describe only the visible code.
52 template<
typename Data,
53 camp::idx_t ArgumentId,
55 typename... EnclosedStmts,
57 struct SyclStatementExecutor<
59 statement::Tile<ArgumentId, TPol, seq_exec, EnclosedStmts...>,
64 using enclosed_stmts_t = SyclStatementListExecutor<Data, stmt_list_t, Types>;
65 using diff_t = segment_diff_type<ArgumentId, Data>;
68 ::sycl::nd_item<3> item,
// `segment` is a reference into data.segment_tuple and is narrowed in place;
// keep a copy to slice from and to restore afterwards.
72 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
75 using segment_t = camp::decay<decltype(segment)>;
76 segment_t orig_segment = segment;
78 diff_t chunk_size = TPol::chunk_size;
81 diff_t len = segment.end() - segment.begin();
// Sequential loop over tile start offsets; the work-item index is not used.
84 for (diff_t i = 0; i < len; i += chunk_size)
// Narrow the segment to the current tile before running the enclosed statements.
88 segment = orig_segment.slice(i, chunk_size);
91 enclosed_stmts_t::exec(data, item, thread_active);
// Put the full segment back into data.segment_tuple.
95 segment = orig_segment;
// Launch dimensions are exactly those the enclosed statements need for a
// single tile; the sequential tile loop adds none of its own.
98 static inline LaunchDims calculateDimensions(Data
const& data)
// Work on a private copy so the caller's data is not modified.
102 using data_t = camp::decay<Data>;
103 data_t private_data = data;
106 auto& segment = camp::get<ArgumentId>(private_data.segment_tuple);
// Shrink the copied segment to one tile's worth of elements.
109 segment = segment.slice(0, TPol::chunk_size);
112 LaunchDims enclosed_dims =
113 enclosed_stmts_t::calculateDimensions(private_data);
115 return enclosed_dims;
// Executor for statement::Tile with tile_fixed<chunk_size> and
// sycl_group_012_direct<BlockDim>: tiles are mapped onto work-groups in
// dimension BlockDim (calculateDimensions requires one work-group per tile).
// NOTE(review): this listing is missing lines (remaining template parameters,
// the full tile-offset expression, braces); comments describe only the
// visible code.
124 template<
typename Data,
125 camp::idx_t ArgumentId,
126 camp::idx_t chunk_size,
128 typename... EnclosedStmts,
130 struct SyclStatementExecutor<Data,
131 statement::Tile<ArgumentId,
132 RAJA::tile_fixed<chunk_size>,
133 sycl_group_012_direct<BlockDim>,
140 using enclosed_stmts_t = SyclStatementListExecutor<Data, stmt_list_t, Types>;
142 using diff_t = segment_diff_type<ArgumentId, Data>;
145 ::sycl::nd_item<3> item,
149 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
151 using segment_t = camp::decay<decltype(segment)>;
154 diff_t len = segment.end() - segment.begin();
// The tile offset is derived from this item's work-group id in BlockDim
// (the remainder of the expression is not visible in this listing).
157 item.get_group(BlockDim) *
// Keep the full segment so it can be sliced and then restored.
165 segment_t orig_segment = segment;
168 segment = orig_segment.slice(i, chunk_size);
171 enclosed_stmts_t::exec(data, item, thread_active);
// Put the full segment back into data.segment_tuple.
174 segment = orig_segment;
// Computes the launch shape: one work-group per tile in BlockDim, merged
// (via max) with what the enclosed statements need for a single tile.
178 static inline LaunchDims calculateDimensions(Data
const& data)
// Tile count; the check below presumably rounds it up for a partial last
// tile (its body is not visible in this listing).
182 diff_t len = segment_length<ArgumentId>(data);
183 diff_t num_blocks = len / chunk_size;
184 if (num_blocks * chunk_size < len)
// Both the launched and the minimum group count are num_blocks, because
// tiles are assigned directly to work-groups without a loop.
190 set_sycl_dim<BlockDim>(dims.group, num_blocks);
193 set_sycl_dim<BlockDim>(dims.min_groups, num_blocks);
// Size the enclosed statements against a private single-tile copy.
197 using data_t = camp::decay<Data>;
198 data_t private_data = data;
201 auto& segment = camp::get<ArgumentId>(private_data.segment_tuple);
204 segment = segment.slice(0, chunk_size);
207 LaunchDims enclosed_dims =
208 enclosed_stmts_t::calculateDimensions(private_data);
210 return dims.max(enclosed_dims);
// Executor for statement::Tile with tile_fixed<chunk_size> and
// sycl_group_012_loop<BlockDim>: work-groups in dimension BlockDim loop over
// tiles, starting at their group id and striding by the number of groups.
// NOTE(review): this listing is missing lines (remaining template parameters,
// function headers, braces); comments describe only the visible code.
219 template<
typename Data,
220 camp::idx_t ArgumentId,
221 camp::idx_t chunk_size,
223 typename... EnclosedStmts,
225 struct SyclStatementExecutor<Data,
226 statement::Tile<ArgumentId,
227 RAJA::tile_fixed<chunk_size>,
228 sycl_group_012_loop<BlockDim>,
235 using enclosed_stmts_t = SyclStatementListExecutor<Data, stmt_list_t, Types>;
237 using diff_t = segment_diff_type<ArgumentId, Data>;
240 ::sycl::nd_item<3> item,
// Keep the full segment so it can be sliced per tile and then restored.
244 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
247 using segment_t = camp::decay<decltype(segment)>;
248 segment_t orig_segment = segment;
251 diff_t len = segment.end() - segment.begin();
// First tile = this group's id; stride = number of groups, both in elements.
252 diff_t i_init = item.get_group(BlockDim) * chunk_size;
253 diff_t i_stride = item.get_group_range(BlockDim) * chunk_size;
256 for (diff_t i = i_init; i < len; i += i_stride)
260 segment = orig_segment.slice(i, chunk_size);
263 enclosed_stmts_t::exec(data, item, thread_active);
// Put the full segment back into data.segment_tuple.
267 segment = orig_segment;
// Requests up to one work-group per tile in BlockDim. No min_groups is set,
// because the loop above tolerates fewer groups than tiles.
270 static inline LaunchDims calculateDimensions(Data
const& data)
// Tile count; the check below presumably rounds it up for a partial last
// tile (its body is not visible in this listing).
274 diff_t len = segment_length<ArgumentId>(data);
275 diff_t num_blocks = len / chunk_size;
276 if (num_blocks * chunk_size < len)
282 set_sycl_dim<BlockDim>(dims.group, num_blocks);
// Size the enclosed statements against a private single-tile copy.
286 using data_t = camp::decay<Data>;
287 data_t private_data = data;
290 auto& segment = camp::get<ArgumentId>(private_data.segment_tuple);
293 segment = segment.slice(0, chunk_size);
296 LaunchDims enclosed_dims =
297 enclosed_stmts_t::calculateDimensions(private_data);
299 return dims.max(enclosed_dims);
// Executor for statement::Tile with tile_fixed<chunk_size> and
// sycl_local_012_direct<ThreadDim>: each work-item in dimension ThreadDim
// handles the single tile at offset local_id * chunk_size.
// NOTE(review): this listing is missing lines (remaining template parameters,
// function headers, braces); comments describe only the visible code.
308 template<
typename Data,
309 camp::idx_t ArgumentId,
310 camp::idx_t chunk_size,
312 typename... EnclosedStmts,
314 struct SyclStatementExecutor<Data,
315 statement::Tile<ArgumentId,
316 RAJA::tile_fixed<chunk_size>,
317 sycl_local_012_direct<ThreadDim>,
324 using enclosed_stmts_t = SyclStatementListExecutor<Data, stmt_list_t, Types>;
326 using diff_t = segment_diff_type<ArgumentId, Data>;
329 ::sycl::nd_item<3> item,
// Keep the full segment so it can be sliced and then restored.
333 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
336 using segment_t = camp::decay<decltype(segment)>;
337 segment_t orig_segment = segment;
340 diff_t len = segment.end() - segment.begin();
341 diff_t i = item.get_local_id(ThreadDim) * chunk_size;
// Work-items past the end of the segment get an empty slice. They still call
// the enclosed statements, but with thread_active cleared. Presumably this is
// so every work-item reaches the same code (e.g. collectives); confirm.
345 bool have_work = i < len;
348 diff_t slice_size = have_work ? chunk_size : 0;
349 segment = orig_segment.slice(i, slice_size);
352 enclosed_stmts_t::exec(data, item, thread_active && have_work);
// Put the full segment back into data.segment_tuple.
355 segment = orig_segment;
// Requires one work-item per tile in ThreadDim (both local and min_locals),
// merged via max with the enclosed statements' needs for a single tile.
358 static inline LaunchDims calculateDimensions(Data
const& data)
// Tile count; the check below presumably rounds it up for a partial last
// tile (its body is not visible in this listing).
362 diff_t len = segment_length<ArgumentId>(data);
363 diff_t num_threads = len / chunk_size;
364 if (num_threads * chunk_size < len)
370 set_sycl_dim<ThreadDim>(dims.local, num_threads);
371 set_sycl_dim<ThreadDim>(dims.min_locals, num_threads);
// Size the enclosed statements against a private single-tile copy.
374 using data_t = camp::decay<Data>;
375 data_t private_data = data;
378 auto& segment = camp::get<ArgumentId>(private_data.segment_tuple);
381 segment = segment.slice(0, chunk_size);
384 LaunchDims enclosed_dims =
385 enclosed_stmts_t::calculateDimensions(private_data);
387 return (dims.max(enclosed_dims));
// Executor for statement::Tile with tile_fixed<chunk_size> and
// sycl_local_012_loop<ThreadDim>: work-items of a work-group in dimension
// ThreadDim loop over tiles. Each starts at local_id * chunk_size and strides
// by the work-group size (local range) times chunk_size.
// NOTE(review): this listing is missing lines (remaining template parameters,
// function headers, braces); comments describe only the visible code.
396 template<
typename Data,
397 camp::idx_t ArgumentId,
398 camp::idx_t chunk_size,
400 typename... EnclosedStmts,
402 struct SyclStatementExecutor<Data,
403 statement::Tile<ArgumentId,
404 RAJA::tile_fixed<chunk_size>,
405 sycl_local_012_loop<ThreadDim>,
412 using enclosed_stmts_t = SyclStatementListExecutor<Data, stmt_list_t, Types>;
414 using diff_t = segment_diff_type<ArgumentId, Data>;
417 ::sycl::nd_item<3> item,
// Keep the full segment so it can be sliced per tile and then restored.
421 auto& segment = camp::get<ArgumentId>(data.segment_tuple);
424 using segment_t = camp::decay<decltype(segment)>;
425 segment_t orig_segment = segment;
428 diff_t len = segment_length<ArgumentId>(data);
// The start offset comes from the local id, which ranges over [0, local_range),
// so the stride must be the local range (work-group size), not
// get_group_range (number of work-groups). With get_group_range, items in a
// group overlap: with one group, item t would run tiles t, t+1, ..., repeating
// tiles. The local range also matches calculateDimensions, which sizes
// dims.local in ThreadDim.
429 diff_t i_init = item.get_local_id(ThreadDim) * chunk_size;
430 diff_t i_stride = item.get_local_range(ThreadDim) * chunk_size;
// The loop counter ii does not depend on the local id, so every work-item runs
// the same number of iterations. Items whose tile i falls past the end get an
// empty slice and run with thread_active cleared.
433 for (diff_t ii = 0; ii < len; ii += i_stride)
435 diff_t i = ii + i_init;
439 bool have_work = i < len;
442 diff_t slice_size = have_work ? chunk_size : 0;
443 segment = orig_segment.slice(i, slice_size);
446 enclosed_stmts_t::exec(data, item, thread_active && have_work);
// Put the full segment back into data.segment_tuple.
450 segment = orig_segment;
// Requests up to one work-item per tile in ThreadDim (at least 1). The minimum
// is 1 because the loop above handles fewer work-items than tiles.
453 static inline LaunchDims calculateDimensions(Data
const& data)
// Tile count; the check below presumably rounds it up for a partial last
// tile (its body is not visible in this listing).
457 diff_t len = segment_length<ArgumentId>(data);
458 diff_t num_threads = len / chunk_size;
459 if (num_threads * chunk_size < len)
463 num_threads =
std::max(num_threads, (diff_t)1);
466 set_sycl_dim<ThreadDim>(dims.local, num_threads);
467 set_sycl_dim<ThreadDim>(dims.min_locals, 1);
// Size the enclosed statements against a private single-tile copy.
470 using data_t = camp::decay<Data>;
471 data_t private_data = data;
474 auto& segment = camp::get<ArgumentId>(private_data.segment_tuple);
477 segment = segment.slice(0, chunk_size);
480 LaunchDims enclosed_dims =
481 enclosed_stmts_t::calculateDimensions(private_data);
483 return (dims.max(enclosed_dims));
Header file for common RAJA internal macro definitions.
#define RAJA_DEVICE
Definition: macros.hpp:66
camp::list< Stmts... > StatementList
Definition: StatementList.hpp:41
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155
Header file for tile wrapper and iterator.
Header file for loop kernel internals.
Header file for RAJA type definitions.