21 #ifndef RAJA_policy_sycl_kernel_internal_HPP
22 #define RAJA_policy_sycl_kernel_internal_HPP
24 #include "RAJA/config.hpp"
26 #if defined(RAJA_ENABLE_SYCL)
31 #include "camp/camp.hpp"
// Two additional sycl_dim_3_t members.
// NOTE(review): neither field is read or written by any code visible in this
// listing; presumably lower bounds on the group/local extents - confirm
// against the full struct definition.
53 sycl_dim_3_t min_groups;
54 sycl_dim_3_t min_locals;
// Constructor taking another LaunchDims `c` by const reference.
// NOTE(review): the member initializers / body are not visible in this
// listing, so which fields are copied cannot be confirmed from here.
70 LaunchDims(LaunchDims
const& c)
// Component-wise maximum of this object and `c`.
//
// For each of the three extent triples (group, local, global) every x/y/z
// field of `result` is set to the larger of `c`'s field and this object's
// corresponding field. The method is const, so `*this` is not modified.
//
// NOTE(review): the declaration of `result`, the return statement and the
// enclosing braces are not visible in this listing.
77 LaunchDims
max(LaunchDims
const& c)
const
// group extents
81 result.group.x =
std::max(c.group.x, group.x);
82 result.group.y =
std::max(c.group.y, group.y);
83 result.group.z =
std::max(c.group.z, group.z);
// local extents
85 result.local.x =
std::max(c.local.x, local.x);
86 result.local.y =
std::max(c.local.y, local.y);
87 result.local.z =
std::max(c.local.z, local.z);
// global extents
89 result.global.x =
std::max(c.global.x, global.x);
90 result.global.y =
std::max(c.global.y, global.y);
91 result.global.z =
std::max(c.global.z, global.z);
// Builds the ::sycl::nd_range<3> used for a launch from this object's
// local / group / global extents, limited by the max_work_group_size reported
// by the device behind queue `q`.
//
// `q` is dereferenced unconditionally, so it must not be null.
//
// NOTE(review): this listing omits several original lines (braces, at least
// one `else`, and some assignment targets); the comments below describe only
// what the visible statements do and flag the gaps.
96 ::sycl::nd_range<3> fit_nd_range(::sycl::queue* q)
99 sycl_dim_3_t launch_global;
// Start from 1 in every dimension so a zero `local` extent becomes 1.
101 sycl_dim_3_t launch_local {1, 1, 1};
102 launch_local.x =
std::max(launch_local.x, local.x);
103 launch_local.y =
std::max(launch_local.y, local.y);
104 launch_local.z =
std::max(launch_local.z, local.z);
// Query the device limit on work-group size.
106 ::sycl::device dev = q->get_device();
108 auto max_work_group_size =
109 dev.get_info<::sycl::info::device::max_work_group_size>();
// First pass: no single dimension may exceed the limit on its own.
111 if (launch_local.x > max_work_group_size)
113 launch_local.x = max_work_group_size;
115 if (launch_local.y > max_work_group_size)
117 launch_local.y = max_work_group_size;
119 if (launch_local.z > max_work_group_size)
121 launch_local.z = max_work_group_size;
// Second pass: the product of the three local extents may still exceed the
// limit. `remaining` tracks how much of the limit is left for the other
// dimensions; z is accounted for first, then y, then x.
127 if (launch_local.x * launch_local.y * launch_local.z > max_work_group_size)
129 unsigned long remaining = 1;
// z is left unchanged; what is left after z is max / z.
132 if (max_work_group_size > launch_local.z)
135 remaining = max_work_group_size / launch_local.z;
// y fits into what is left: keep y and divide it out of the budget.
137 if (remaining >= launch_local.y)
140 remaining = remaining / launch_local.y;
// NOTE(review): presumably the `else` branch of the test above (the `else`
// line is not visible). y is shrunk to the budget; the following division
// then divides `remaining` by the value it was just assigned from.
144 launch_local.y = remaining;
145 remaining = remaining / launch_local.y;
// x gets whatever budget is left.
147 if (remaining < launch_local.x)
149 launch_local.x = remaining;
// If any group extent is non-zero, the global range is derived from the
// group counts: global = local * max(1, group) per dimension.
155 if (group.x != 0 || group.y != 0 || group.z != 0)
157 sycl_dim_3_t launch_group {1, 1, 1};
158 launch_group.x =
std::max(launch_group.x, group.x);
159 launch_group.y =
std::max(launch_group.y, group.y);
160 launch_group.z =
std::max(launch_group.z, group.z);
162 launch_global.x = launch_local.x * launch_group.x;
163 launch_global.y = launch_local.y * launch_group.y;
164 launch_global.z = launch_local.z * launch_group.z;
// Each expression below rounds `global.<dim>` up to a multiple of
// `launch_local.<dim>`.
// NOTE(review): the left-hand sides are not visible; presumably these assign
// launch_global.x/y/z in the branch taken when all group extents are zero -
// confirm against the full source.
169 launch_local.x * ((global.x + (launch_local.x - 1)) / launch_local.x);
171 launch_local.y * ((global.y + (launch_local.y - 1)) / launch_local.y);
173 launch_local.z * ((global.z + (launch_local.z - 1)) / launch_local.z);
// If a global extent is not a multiple of its local extent, the expression
// that follows each test is the next multiple of the local extent above it.
// NOTE(review): the assignment targets are not visible (presumably
// launch_global.x/y/z).
177 if (launch_global.x % launch_local.x != 0)
180 ((launch_global.x / launch_local.x) + 1) * launch_local.x;
182 if (launch_global.y % launch_local.y != 0)
185 ((launch_global.y / launch_local.y) + 1) * launch_local.y;
187 if (launch_global.z % launch_local.z != 0)
190 ((launch_global.z / launch_local.z) + 1) * launch_local.z;
// Pack the extents into SYCL ranges and return nd_range(global, local).
// NOTE(review): the line completing the ret_gl initializer is not visible.
193 ::sycl::range<3> ret_th = {launch_local.x, launch_local.y, launch_local.z};
194 ::sycl::range<3> ret_gl = {launch_global.x, launch_global.y,
197 return ::sycl::nd_range<3>(ret_gl, ret_th);
// Recursive helper that walks a compile-time list of statement executors.
//
// Template parameters:
//   cur_stmt  - index of the statement handled by this instantiation
//   num_stmts - total number of statements in the list
//   StmtList  - list type indexed with camp::at_v
//
// Each instantiation handles statement `cur_stmt` and then delegates to the
// instantiation for `cur_stmt + 1`.
201 template<camp::
idx_t cur_stmt, camp::
idx_t num_stmts,
typename StmtList>
202 struct SyclStatementListExecutorHelper
// Helper for the next statement index.
205 using next_helper_t =
206 SyclStatementListExecutorHelper<cur_stmt + 1, num_stmts, StmtList>;
// Executor type for the current statement.
208 using cur_stmt_t = camp::at_v<StmtList, cur_stmt>;
// exec: runs the current statement's exec, then the next helper's exec, with
// the same (data, item, thread_active) arguments.
// NOTE(review): the full signature of exec is not visible in this listing.
210 template<
typename Data>
212 ::sycl::nd_item<3> item,
216 cur_stmt_t::exec(data, item, thread_active);
219 next_helper_t::exec(data, item, thread_active);
// calculateDimensions: combines the current statement's LaunchDims with the
// dimensions of the remaining statements via LaunchDims::max and returns the
// result.
222 template<
typename Data>
223 inline static LaunchDims calculateDimensions(Data& data)
226 LaunchDims statement_dims = cur_stmt_t::calculateDimensions(data);
229 LaunchDims next_dims = next_helper_t::calculateDimensions(data);
232 return statement_dims.max(next_dims);
// Partial specialization for cur_stmt == num_stmts, i.e. the index one past
// the last statement; this is the case the primary template's recursion
// (cur_stmt + 1) eventually reaches.
// NOTE(review): the member bodies are not visible in this listing.
// calculateDimensions takes an unnamed Data&, so it does not read its
// argument; the value it returns cannot be confirmed from here.
236 template<camp::
idx_t num_stmts,
typename StmtList>
237 struct SyclStatementListExecutorHelper<num_stmts, num_stmts, StmtList>
240 template<
typename Data>
248 template<
typename Data>
249 inline static LaunchDims calculateDimensions(Data&)
// Forward declaration of the per-statement executor template, parameterized
// on <Data, Policy, Types>. No definition is visible in this listing.
255 template<
typename Data,
typename Policy,
typename Types>
256 struct SyclStatementExecutor;
// Forward declaration of the statement-list executor template, parameterized
// on <Data, StmtList, Types>. A partial specialization for
// StatementList<Stmts...> is declared later in this listing.
258 template<
typename Data,
typename StmtList,
typename Types>
259 struct SyclStatementListExecutor;
// Partial specialization of SyclStatementListExecutor for a
// StatementList<Stmts...>. It maps every statement to its
// SyclStatementExecutor and drives them through
// SyclStatementListExecutorHelper starting at index 0.
261 template<
typename Data,
typename... Stmts,
typename Types>
262 struct SyclStatementListExecutor<Data,
StatementList<Stmts...>, Types>
// One SyclStatementExecutor<Data, Stmt, Types> per statement in the pack.
265 using enclosed_stmts_t =
266 camp::list<SyclStatementExecutor<Data, Stmts, Types>...>;
// Number of statements in the pack.
268 static constexpr
size_t num_stmts =
sizeof...(Stmts);
// exec: forwards (data, item, thread_active) to the helper for index 0.
// NOTE(review): the full signature of exec is not visible in this listing.
271 ::sycl::nd_item<3> item,
275 SyclStatementListExecutorHelper<0, num_stmts, enclosed_stmts_t>::exec(
276 data, item, thread_active);
// calculateDimensions: returns the helper's calculateDimensions(data) result,
// starting from statement index 0.
279 static inline LaunchDims calculateDimensions(Data
const& data)
282 return SyclStatementListExecutorHelper<
283 0, num_stmts, enclosed_stmts_t>::calculateDimensions(data);
// Convenience alias for SyclStatementListExecutor.
// Note the parameter order: the alias takes <StmtList, Data, Types>, while
// the underlying struct takes <Data, StmtList, Types>.
287 template<
typename StmtList,
typename Data,
typename Types>
288 using sycl_statement_list_executor_t =
289 SyclStatementListExecutor<Data, StmtList, Types>;
Header file defining prototypes for routines used to manage memory for SYCL reductions and other operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
#define RAJA_UNUSED_ARG(x)
Definition: macros.hpp:97
#define RAJA_DEVICE
Definition: macros.hpp:66
camp::list< Stmts... > StatementList
Definition: StatementList.hpp:41
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155
RAJA header file containing user interface for RAJA::kernel.
Header file containing RAJA SYCL policy definitions.
Header file for RAJA type definitions.