RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
TileTCount.hpp
Go to the documentation of this file.
1 
12 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
13 // Copyright (c) Lawrence Livermore National Security, LLC and other
14 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
15 // files for dates and other details. No copyright assignment is required
16 // to contribute to RAJA.
17 //
18 // SPDX-License-Identifier: (BSD-3-Clause)
19 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
20 
21 
22 #ifndef RAJA_policy_sycl_kernel_TileTCount_HPP
23 #define RAJA_policy_sycl_kernel_TileTCount_HPP
24 
25 #include "RAJA/config.hpp"
26 
27 #if defined(RAJA_ENABLE_SYCL)
28 
29 #include <iostream>
30 #include <type_traits>
31 
32 #include "camp/camp.hpp"
33 #include "camp/concepts.hpp"
34 #include "camp/tuple.hpp"
35 
36 #include "RAJA/util/macros.hpp"
37 #include "RAJA/util/types.hpp"
38 
41 
42 namespace RAJA
43 {
44 namespace internal
45 {
46 
52 template<typename Data,
53  camp::idx_t ArgumentId,
54  typename ParamId,
55  typename TPol,
56  typename... EnclosedStmts,
57  typename Types>
58 struct SyclStatementExecutor<
59  Data,
60  statement::
61  TileTCount<ArgumentId, ParamId, TPol, seq_exec, EnclosedStmts...>,
62  Types>
63  : public SyclStatementExecutor<
64  Data,
65  statement::Tile<ArgumentId, TPol, seq_exec, EnclosedStmts...>,
66  Types>
67 {
68 
69  using Base = SyclStatementExecutor<
70  Data,
71  statement::Tile<ArgumentId, TPol, seq_exec, EnclosedStmts...>,
72  Types>;
73 
74  using typename Base::diff_t;
75  using typename Base::enclosed_stmts_t;
76 
77  static inline RAJA_DEVICE void exec(Data& data,
78  ::sycl::nd_item<3> item,
79  bool thread_active)
80  {
81  // Get the segment referenced by this Tile statement
82  auto& segment = camp::get<ArgumentId>(data.segment_tuple);
83 
84  // Keep copy of original segment, so we can restore it
85  using segment_t = camp::decay<decltype(segment)>;
86  segment_t orig_segment = segment;
87 
88  diff_t chunk_size = TPol::chunk_size;
89 
90  // compute trip count
91  diff_t len = segment.end() - segment.begin();
92 
93  // Iterate through tiles
94  for (diff_t i = 0, t = 0; i < len; i += chunk_size, ++t)
95  {
96 
97  // Assign our new tiled segment
98  segment = orig_segment.slice(i, chunk_size);
99  data.template assign_param<ParamId>(t);
100 
101  // execute enclosed statements
102  enclosed_stmts_t::exec(data, item, thread_active);
103  }
104 
105  // Set range back to original values
106  segment = orig_segment;
107  }
108 };
109 
115 template<typename Data,
116  camp::idx_t ArgumentId,
117  typename ParamId,
118  camp::idx_t chunk_size,
119  int BlockDim,
120  typename... EnclosedStmts,
121  typename Types>
122 struct SyclStatementExecutor<
123  Data,
124  statement::TileTCount<ArgumentId,
125  ParamId,
126  RAJA::tile_fixed<chunk_size>,
127  sycl_group_012_direct<BlockDim>,
128  EnclosedStmts...>,
129  Types>
130  : public SyclStatementExecutor<
131  Data,
132  statement::Tile<ArgumentId,
133  RAJA::tile_fixed<chunk_size>,
134  sycl_group_012_direct<BlockDim>,
135  EnclosedStmts...>,
136  Types>
137 {
138 
139  using Base =
140  SyclStatementExecutor<Data,
141  statement::Tile<ArgumentId,
143  sycl_group_012_direct<BlockDim>,
144  EnclosedStmts...>,
145  Types>;
146 
147  using typename Base::diff_t;
148  using typename Base::enclosed_stmts_t;
149 
150  static inline RAJA_DEVICE void exec(Data& data,
151  ::sycl::nd_item<3> item,
152  bool thread_active)
153  {
154  // Get the segment referenced by this Tile statement
155  auto& segment = camp::get<ArgumentId>(data.segment_tuple);
156 
157  using segment_t = camp::decay<decltype(segment)>;
158 
159  // compute trip count
160  diff_t len = segment.end() - segment.begin();
161  // diff_t t = get_sycl_dim<BlockDim>(blockIdx);
162  diff_t t = item.get_group(BlockDim);
163  diff_t i = t * chunk_size;
164 
165  // check have a chunk
166  if (i < len)
167  {
168 
169  // Keep copy of original segment, so we can restore it
170  segment_t orig_segment = segment;
171 
172  // Assign our new tiled segment
173  segment = orig_segment.slice(i, chunk_size);
174  data.template assign_param<ParamId>(t);
175 
176  // execute enclosed statements
177  enclosed_stmts_t::exec(data, item, thread_active);
178 
179  // Set range back to original values
180  segment = orig_segment;
181  }
182  }
183 };
184 
190 template<typename Data,
191  camp::idx_t ArgumentId,
192  typename ParamId,
193  camp::idx_t chunk_size,
194  int BlockDim,
195  typename... EnclosedStmts,
196  typename Types>
197 struct SyclStatementExecutor<
198  Data,
199  statement::TileTCount<ArgumentId,
200  ParamId,
201  RAJA::tile_fixed<chunk_size>,
202  sycl_group_012_loop<BlockDim>,
203  EnclosedStmts...>,
204  Types>
205  : public SyclStatementExecutor<
206  Data,
207  statement::Tile<ArgumentId,
208  RAJA::tile_fixed<chunk_size>,
209  sycl_group_012_loop<BlockDim>,
210  EnclosedStmts...>,
211  Types>
212 {
213 
214  using Base =
215  SyclStatementExecutor<Data,
216  statement::Tile<ArgumentId,
218  sycl_group_012_loop<BlockDim>,
219  EnclosedStmts...>,
220  Types>;
221 
222  using typename Base::diff_t;
223  using typename Base::enclosed_stmts_t;
224 
225  static inline RAJA_DEVICE void exec(Data& data,
226  ::sycl::nd_item<3> item,
227  bool thread_active)
228  {
229  // Get the segment referenced by this Tile statement
230  auto& segment = camp::get<ArgumentId>(data.segment_tuple);
231 
232  // Keep copy of original segment, so we can restore it
233  using segment_t = camp::decay<decltype(segment)>;
234  segment_t orig_segment = segment;
235 
236  // compute trip count
237  diff_t len = segment.end() - segment.begin();
238  diff_t t_init = item.get_group(BlockDim);
239  diff_t i_init = t_init * chunk_size;
240  diff_t t_stride = item.get_group_range(BlockDim);
241  diff_t i_stride = t_stride * chunk_size;
242 
243  // Iterate through grid stride of chunks
244  for (diff_t i = i_init, t = t_init; i < len; i += i_stride, t += t_stride)
245  {
246 
247  // Assign our new tiled segment
248  segment = orig_segment.slice(i, chunk_size);
249  data.template assign_param<ParamId>(t);
250 
251  // execute enclosed statements
252  enclosed_stmts_t::exec(data, item, thread_active);
253  }
254 
255  // Set range back to original values
256  segment = orig_segment;
257  }
258 };
259 
265 template<typename Data,
266  camp::idx_t ArgumentId,
267  typename ParamId,
268  camp::idx_t chunk_size,
269  int ThreadDim,
270  typename... EnclosedStmts,
271  typename Types>
272 struct SyclStatementExecutor<
273  Data,
274  statement::TileTCount<ArgumentId,
275  ParamId,
276  RAJA::tile_fixed<chunk_size>,
277  sycl_local_012_direct<ThreadDim>,
278  EnclosedStmts...>,
279  Types>
280  : public SyclStatementExecutor<
281  Data,
282  statement::Tile<ArgumentId,
283  RAJA::tile_fixed<chunk_size>,
284  sycl_local_012_direct<ThreadDim>,
285  EnclosedStmts...>,
286  Types>
287 {
288 
289  using Base =
290  SyclStatementExecutor<Data,
291  statement::Tile<ArgumentId,
293  sycl_local_012_direct<ThreadDim>,
294  EnclosedStmts...>,
295  Types>;
296 
297  using typename Base::diff_t;
298  using typename Base::enclosed_stmts_t;
299 
300  static inline RAJA_DEVICE void exec(Data& data,
301  ::sycl::nd_item<3> item,
302  bool thread_active)
303  {
304  // Get the segment referenced by this Tile statement
305  auto& segment = camp::get<ArgumentId>(data.segment_tuple);
306 
307  // Keep copy of original segment, so we can restore it
308  using segment_t = camp::decay<decltype(segment)>;
309  segment_t orig_segment = segment;
310 
311  // compute trip count
312  diff_t len = segment.end() - segment.begin();
313  // diff_t t = get_sycl_dim<ThreadDim>(threadIdx);
314  diff_t t = item.get_local_id(ThreadDim);
315  diff_t i = t * chunk_size;
316 
317  // execute enclosed statements if any thread will
318  // but mask off threads without work
319  bool have_work = i < len;
320 
321  // Assign our new tiled segment
322  diff_t slice_size = have_work ? chunk_size : 0;
323  segment = orig_segment.slice(i, slice_size);
324  data.template assign_param<ParamId>(t);
325 
326  // execute enclosed statements
327  enclosed_stmts_t::exec(data, item, thread_active && have_work);
328 
329  // Set range back to original values
330  segment = orig_segment;
331  }
332 };
333 
339 template<typename Data,
340  camp::idx_t ArgumentId,
341  typename ParamId,
342  camp::idx_t chunk_size,
343  int ThreadDim,
344  typename... EnclosedStmts,
345  typename Types>
346 struct SyclStatementExecutor<
347  Data,
348  statement::TileTCount<ArgumentId,
349  ParamId,
350  RAJA::tile_fixed<chunk_size>,
351  sycl_local_012_loop<ThreadDim>,
352  EnclosedStmts...>,
353  Types>
354  : public SyclStatementExecutor<
355  Data,
356  statement::Tile<ArgumentId,
357  RAJA::tile_fixed<chunk_size>,
358  sycl_local_012_loop<ThreadDim>,
359  EnclosedStmts...>,
360  Types>
361 {
362 
363  using Base =
364  SyclStatementExecutor<Data,
365  statement::Tile<ArgumentId,
367  sycl_local_012_loop<ThreadDim>,
368  EnclosedStmts...>,
369  Types>;
370 
371  using typename Base::diff_t;
372  using typename Base::enclosed_stmts_t;
373 
374  static inline RAJA_DEVICE void exec(Data& data,
375  ::sycl::nd_item<3> item,
376  bool thread_active)
377  {
378  // Get the segment referenced by this Tile statement
379  auto& segment = camp::get<ArgumentId>(data.segment_tuple);
380 
381  // Keep copy of original segment, so we can restore it
382  using segment_t = camp::decay<decltype(segment)>;
383  segment_t orig_segment = segment;
384 
385  // compute trip count
386  diff_t len = segment_length<ArgumentId>(data);
387  // diff_t t_init = get_sycl_dim<ThreadDim>(threadIdx);
388  diff_t t_init = item.get_local_id(ThreadDim);
389  diff_t i_init = t_init * chunk_size;
390  // diff_t t_stride = get_sycl_dim<ThreadDim>(blockDim);
391  diff_t t_stride = item.get_local_range(ThreadDim);
392  diff_t i_stride = t_stride * chunk_size;
393 
394  // Iterate through grid stride of chunks
395  for (diff_t ii = 0, t = t_init; ii < len; ii += i_stride, t += t_stride)
396  {
397  diff_t i = ii + i_init;
398 
399  // execute enclosed statements if any thread will
400  // but mask off threads without work
401  bool have_work = i < len;
402 
403  // Assign our new tiled segment
404  diff_t slice_size = have_work ? chunk_size : 0;
405  segment = orig_segment.slice(i, slice_size);
406  data.template assign_param<ParamId>(t);
407 
408  // execute enclosed statements
409  enclosed_stmts_t::exec(data, item, thread_active && have_work);
410  }
411 
412  // Set range back to original values
413  segment = orig_segment;
414  }
415 };
416 
417 } // end namespace internal
418 } // end namespace RAJA
419 
420 #endif // RAJA_ENABLE_SYCL
421 #endif /* RAJA_policy_sycl_kernel_TileTCount_HPP */
Header file for common RAJA internal macro definitions.
#define RAJA_DEVICE
Definition: macros.hpp:66
Definition: AlignedRangeIndexSetBuilders.cpp:35
Header file for tile wrapper and iterator.
Header file for loop kernel internals.
Tag for a tiling loop.
Definition: Tile.hpp:72
Header file for RAJA type definitions.