RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
internal.hpp
Go to the documentation of this file.
1 
12 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
13 // Copyright (c) Lawrence Livermore National Security, LLC and other
14 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
15 // files for dates and other details. No copyright assignment is required
16 // to contribute to RAJA.
17 //
18 // SPDX-License-Identifier: (BSD-3-Clause)
19 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
20 
21 #ifndef RAJA_policy_sycl_kernel_internal_HPP
22 #define RAJA_policy_sycl_kernel_internal_HPP
23 
24 #include "RAJA/config.hpp"
25 
26 #if defined(RAJA_ENABLE_SYCL)
27 
28 #include <cassert>
29 #include <climits>
30 
31 #include "camp/camp.hpp"
32 
33 #include "RAJA/pattern/kernel.hpp"
34 
35 #include "RAJA/util/macros.hpp"
36 #include "RAJA/util/types.hpp"
37 
40 
41 namespace RAJA
42 {
43 
44 namespace internal
45 {
46 
47 // LaunchDims and Helper functions
48 struct LaunchDims
49 {
50  sycl_dim_3_t group;
51  sycl_dim_3_t local;
52  sycl_dim_3_t global;
53  sycl_dim_3_t min_groups;
54  sycl_dim_3_t min_locals;
55 
56  RAJA_INLINE
57 
59  LaunchDims()
60  : group {0, 0, 0},
61  local {1, 1, 1},
62  global {1, 1, 1},
63  min_groups {0, 0, 0},
64  min_locals {0, 0, 0}
65  {}
66 
67  RAJA_INLINE
68 
70  LaunchDims(LaunchDims const& c)
71  : group(c.group),
72  local(c.local),
73  global(c.global)
74  {}
75 
76  RAJA_INLINE
77  LaunchDims max(LaunchDims const& c) const
78  {
79  LaunchDims result;
80 
81  result.group.x = std::max(c.group.x, group.x);
82  result.group.y = std::max(c.group.y, group.y);
83  result.group.z = std::max(c.group.z, group.z);
84 
85  result.local.x = std::max(c.local.x, local.x);
86  result.local.y = std::max(c.local.y, local.y);
87  result.local.z = std::max(c.local.z, local.z);
88 
89  result.global.x = std::max(c.global.x, global.x);
90  result.global.y = std::max(c.global.y, global.y);
91  result.global.z = std::max(c.global.z, global.z);
92 
93  return result;
94  }
95 
96  ::sycl::nd_range<3> fit_nd_range(::sycl::queue* q)
97  {
98 
99  sycl_dim_3_t launch_global;
100 
101  sycl_dim_3_t launch_local {1, 1, 1};
102  launch_local.x = std::max(launch_local.x, local.x);
103  launch_local.y = std::max(launch_local.y, local.y);
104  launch_local.z = std::max(launch_local.z, local.z);
105 
106  ::sycl::device dev = q->get_device();
107 
108  auto max_work_group_size =
109  dev.get_info<::sycl::info::device::max_work_group_size>();
110 
111  if (launch_local.x > max_work_group_size)
112  {
113  launch_local.x = max_work_group_size;
114  }
115  if (launch_local.y > max_work_group_size)
116  {
117  launch_local.y = max_work_group_size;
118  }
119  if (launch_local.z > max_work_group_size)
120  {
121  launch_local.z = max_work_group_size;
122  }
123 
124 
125  // Make sure the multiple of locals fits
126  // Prefer larger z -> y -> x
127  if (launch_local.x * launch_local.y * launch_local.z > max_work_group_size)
128  {
129  unsigned long remaining = 1;
130  // local z cannot be > max_wrk from above
131  // if equal then remaining is 1, on handle <
132  if (max_work_group_size > launch_local.z)
133  {
134  // keep local z
135  remaining = max_work_group_size / launch_local.z;
136  }
137  if (remaining >= launch_local.y)
138  {
139  // keep local y
140  remaining = remaining / launch_local.y;
141  }
142  else
143  {
144  launch_local.y = remaining;
145  remaining = remaining / launch_local.y;
146  }
147  if (remaining < launch_local.x)
148  {
149  launch_local.x = remaining;
150  }
151  }
152 
153 
154  // User gave group policy, use to calculate global space
155  if (group.x != 0 || group.y != 0 || group.z != 0)
156  {
157  sycl_dim_3_t launch_group {1, 1, 1};
158  launch_group.x = std::max(launch_group.x, group.x);
159  launch_group.y = std::max(launch_group.y, group.y);
160  launch_group.z = std::max(launch_group.z, group.z);
161 
162  launch_global.x = launch_local.x * launch_group.x;
163  launch_global.y = launch_local.y * launch_group.y;
164  launch_global.z = launch_local.z * launch_group.z;
165  }
166  else
167  {
168  launch_global.x =
169  launch_local.x * ((global.x + (launch_local.x - 1)) / launch_local.x);
170  launch_global.y =
171  launch_local.y * ((global.y + (launch_local.y - 1)) / launch_local.y);
172  launch_global.z =
173  launch_local.z * ((global.z + (launch_local.z - 1)) / launch_local.z);
174  }
175 
176 
177  if (launch_global.x % launch_local.x != 0)
178  {
179  launch_global.x =
180  ((launch_global.x / launch_local.x) + 1) * launch_local.x;
181  }
182  if (launch_global.y % launch_local.y != 0)
183  {
184  launch_global.y =
185  ((launch_global.y / launch_local.y) + 1) * launch_local.y;
186  }
187  if (launch_global.z % launch_local.z != 0)
188  {
189  launch_global.z =
190  ((launch_global.z / launch_local.z) + 1) * launch_local.z;
191  }
192 
193  ::sycl::range<3> ret_th = {launch_local.x, launch_local.y, launch_local.z};
194  ::sycl::range<3> ret_gl = {launch_global.x, launch_global.y,
195  launch_global.z};
196 
197  return ::sycl::nd_range<3>(ret_gl, ret_th);
198  }
199 };
200 
201 template<camp::idx_t cur_stmt, camp::idx_t num_stmts, typename StmtList>
202 struct SyclStatementListExecutorHelper
203 {
204 
205  using next_helper_t =
206  SyclStatementListExecutorHelper<cur_stmt + 1, num_stmts, StmtList>;
207 
208  using cur_stmt_t = camp::at_v<StmtList, cur_stmt>;
209 
210  template<typename Data>
211  inline static RAJA_DEVICE void exec(Data& data,
212  ::sycl::nd_item<3> item,
213  bool thread_active)
214  {
215  // Execute stmt
216  cur_stmt_t::exec(data, item, thread_active);
217 
218  // Execute next stmt
219  next_helper_t::exec(data, item, thread_active);
220  }
221 
222  template<typename Data>
223  inline static LaunchDims calculateDimensions(Data& data)
224  {
225  // Compute this statements launch dimensions
226  LaunchDims statement_dims = cur_stmt_t::calculateDimensions(data);
227 
228  // call the next statement in the list
229  LaunchDims next_dims = next_helper_t::calculateDimensions(data);
230 
231  // Return the maximum of the two
232  return statement_dims.max(next_dims);
233  }
234 };
235 
236 template<camp::idx_t num_stmts, typename StmtList>
237 struct SyclStatementListExecutorHelper<num_stmts, num_stmts, StmtList>
238 {
239 
240  template<typename Data>
241  inline static RAJA_DEVICE void exec(Data&,
242  ::sycl::nd_item<3> RAJA_UNUSED_ARG(item),
243  bool)
244  {
245  // nop terminator
246  }
247 
248  template<typename Data>
249  inline static LaunchDims calculateDimensions(Data&)
250  {
251  return LaunchDims();
252  }
253 };
254 
//! Executor for one kernel statement.  Specializations are not in this
//! view; presumably they live in the per-statement SYCL kernel headers.
template<typename DataT, typename PolicyT, typename TypesT>
struct SyclStatementExecutor;

//! Executor for a whole statement list; the StatementList<...>
//! specialization follows.
template<typename DataT, typename StmtListT, typename TypesT>
struct SyclStatementListExecutor;
260 
261 template<typename Data, typename... Stmts, typename Types>
262 struct SyclStatementListExecutor<Data, StatementList<Stmts...>, Types>
263 {
264 
265  using enclosed_stmts_t =
266  camp::list<SyclStatementExecutor<Data, Stmts, Types>...>;
267 
268  static constexpr size_t num_stmts = sizeof...(Stmts);
269 
270  static inline RAJA_DEVICE void exec(Data& data,
271  ::sycl::nd_item<3> item,
272  bool thread_active)
273  {
274  // Execute statements in order with helper class
275  SyclStatementListExecutorHelper<0, num_stmts, enclosed_stmts_t>::exec(
276  data, item, thread_active);
277  }
278 
279  static inline LaunchDims calculateDimensions(Data const& data)
280  {
281  // Compute this statements launch dimensions
282  return SyclStatementListExecutorHelper<
283  0, num_stmts, enclosed_stmts_t>::calculateDimensions(data);
284  }
285 };
286 
287 template<typename StmtList, typename Data, typename Types>
288 using sycl_statement_list_executor_t =
289  SyclStatementListExecutor<Data, StmtList, Types>;
290 
291 } // namespace internal
292 } // namespace RAJA
293 
294 #endif // closing endif for RAJA_ENABLE_SYCL guard
295 
296 #endif // closing endif for header file include guard
Header file defining prototypes for routines used to manage memory for SYCL reductions and other operations.
Header file for common RAJA internal macro definitions.
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
#define RAJA_UNUSED_ARG(x)
Definition: macros.hpp:97
#define RAJA_DEVICE
Definition: macros.hpp:66
camp::list< Stmts... > StatementList
Definition: StatementList.hpp:41
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE Result max(Args... args)
Definition: foldl.hpp:155
RAJA header file containing user interface for RAJA::kernel.
Header file containing RAJA SYCL policy definitions.
Header file for RAJA type definitions.