RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
launch.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_pattern_launch_sycl_HPP
21 #define RAJA_pattern_launch_sycl_HPP
22 
27 // #include "RAJA/policy/sycl/raja_syclerrchk.hpp"
28 #include "RAJA/util/resource.hpp"
29 
30 namespace RAJA
31 {
32 
// SYCL specialization of LaunchExecute for RAJA::sycl_launch_t<async, 0>.
//
// exec() reads a ::sycl::queue from the resource, builds a 3D nd_range from
// launch_params.threads / launch_params.teams and submits the loop body as a
// parallel_for, with or without a ::sycl::reduction depending on
// is_parampack_empty.
//
// NOTE(review): this listing is an extract; listing lines 39, 51, 61, 105,
// 110, 120, 128, 142, 146 and 157 are absent. The comments below describe
// only what the visible lines show.
33 template<bool async>
34 struct LaunchExecute<RAJA::sycl_launch_t<async, 0>>
35 {
    // Returns an EventProxy wrapping the resource that was passed in.
36  template<typename LoopBody, typename ReduceParams>
37  static concepts::enable_if_t<
38  resources::EventProxy<resources::Resource>,
40  exec(RAJA::resources::Resource res,
41  const LaunchParams& launch_params,
42  LoopBody&& loop_body,
43  ReduceParams launch_reducers)
44  {
45  using EXEC_POL = RAJA::sycl_launch_t<async, 0>;
46  using LOOP_BODY = camp::decay<LoopBody>;
47  // Deduce at compile time if lbody is trivially copyable and if user
48  // has supplied parameters. These will be used to determine which sycl
49  // launch to configure below.
50  constexpr bool is_parampack_empty =
52  constexpr bool is_lbody_trivially_copyable =
53  std::is_trivially_copyable<LOOP_BODY>::value;
    // NOTE(review): pol is not referenced by any visible line; it may be
    // used by the absent lines - confirm.
54  EXEC_POL pol {};
55 
56  /*Get the queue from concrete resource */
57  ::sycl::queue* q = res.get<camp::resources::Sycl>().get_queue();
58 
59  if constexpr (!is_parampack_empty)
60  {
62  }
63 
64  //
65  // Compute the number of blocks and threads
66  //
    // Both ranges take the launch dimensions in reverse index order
    // (value[2], value[1], value[0]).
67  const ::sycl::range<3> blockSize(launch_params.threads.value[2],
68  launch_params.threads.value[1],
69  launch_params.threads.value[0]);
70 
    // Global range = threads * teams per dimension.
71  const ::sycl::range<3> gridSize(
72  launch_params.threads.value[2] * launch_params.teams.value[2],
73  launch_params.threads.value[1] * launch_params.teams.value[1],
74  launch_params.threads.value[0] * launch_params.teams.value[0]);
75 
76  // Only launch kernel if we have something to iterate over
77  constexpr int zero = 0;
78  if (launch_params.threads.value[0] <= zero ||
79  launch_params.threads.value[1] <= zero ||
80  launch_params.threads.value[2] <= zero ||
81  launch_params.teams.value[0] <= zero ||
82  launch_params.teams.value[1] <= zero ||
83  launch_params.teams.value[2] <= zero)
84  {
85  return resources::EventProxy<resources::Resource>(res);
86  }
87 
88 
    // Repeats the alias already declared at the top of exec().
89  using LOOP_BODY = camp::decay<LoopBody>;
90  LOOP_BODY* lbody = nullptr;
91  //
92  // Kernel body is nontrivially copyable, create space on device and copy
93  // to Workaround until "is_device_copyable" is supported
94  //
95  if constexpr (!is_lbody_trivially_copyable)
96  {
97  lbody = (LOOP_BODY*)::sycl::malloc_device(sizeof(LOOP_BODY), *q);
    // Blocks until the byte copy of loop_body to the device has finished.
98  q->memcpy(lbody, &loop_body, sizeof(LOOP_BODY)).wait();
99  }
100  // Both the parallel_for call, combinations, and resolution are all
101  // unique to the parameter case, so we make a constexpr branch here
102  if constexpr (!is_parampack_empty)
103  {
104  auto combiner = [](ReduceParams x, ReduceParams y) {
106  return x;
107  };
108 
    // NOTE: within this branch `res` names this ReduceParams pointer and
    // shadows the Resource parameter `res` of exec().
109  ReduceParams* res = ::sycl::malloc_shared<ReduceParams>(1, *q);
111  auto reduction = ::sycl::reduction(res, launch_reducers, combiner);
112 
113  q->submit([&](::sycl::handler& h) {
    // Local accessor of launch_params.shared_mem_size chars.
114  auto s_vec =
115  ::sycl::local_accessor<char, 1>(launch_params.shared_mem_size, h);
116 
117  h.parallel_for(
118  ::sycl::nd_range<3>(gridSize, blockSize), reduction,
119  [=](::sycl::nd_item<3> itm, auto& red) {
121  ctx.itm = &itm;
122 
123  // Point to shared memory
124  ctx.shared_mem_ptr =
125  s_vec.get_multi_ptr<::sycl::access::decorated::yes>().get();
126 
127  ReduceParams fp;
    // Use the captured body directly when it is trivially copyable,
    // otherwise the copy placed in device memory above.
129  if constexpr (is_lbody_trivially_copyable)
130  {
131  RAJA::expt::invoke_body(fp, loop_body, ctx);
132  }
133  else
134  {
135  RAJA::expt::invoke_body(fp, *lbody, ctx);
136  }
137 
138  red.combine(fp);
139  });
140  }).wait(); // Need to wait for completion to free memory
141 
143  *res);
144  ::sycl::free(res, *q);
    // lbody is still nullptr here when the body is trivially copyable.
145  ::sycl::free(lbody, *q);
147  }
148  else
149  {
150  q->submit([&](::sycl::handler& h) {
151  auto s_vec =
152  ::sycl::local_accessor<char, 1>(launch_params.shared_mem_size, h);
153 
154  h.parallel_for(
155  ::sycl::nd_range<3>(gridSize, blockSize),
156  [=](::sycl::nd_item<3> itm) {
158  ctx.itm = &itm;
159 
160  // Point to shared memory
161  ctx.shared_mem_ptr =
162  s_vec.get_multi_ptr<::sycl::access::decorated::yes>().get();
163  if constexpr (is_lbody_trivially_copyable)
164  {
165  loop_body(ctx);
166  }
167  else
168  {
169  (*lbody)(ctx);
170  }
171  });
172  });
173 
    // The queue is waited on only for the synchronous policy.
174  if (!async)
175  {
176  q->wait();
177  }
    // NOTE(review): the visible lines of this branch never pass lbody to
    // ::sycl::free, unlike the reduction branch above - this looks like a
    // device allocation leak when the body is not trivially copyable;
    // confirm.
178  }
179 
180  return resources::EventProxy<resources::Resource>(res);
181  }
182 };
183 
184 /*
185  SYCL global thread mapping
186 */
187 template<int... DIM>
189 
193 
194 template<typename SEGMENT, int DIM>
195 struct LoopExecute<sycl_global_item<DIM>, SEGMENT>
196 {
197 
198  template<typename BODY>
199  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
200  SEGMENT const& segment,
201  BODY const& body)
202  {
203 
204  const int len = segment.end() - segment.begin();
205  {
206  const int tx = ctx.itm->get_group(DIM) * ctx.itm->get_local_range(DIM) +
207  ctx.itm->get_local_id(DIM);
208 
209  if (tx < len) body(*(segment.begin() + tx));
210  }
211  }
212 };
213 
220 
221 template<typename SEGMENT, int DIM0, int DIM1>
222 struct LoopExecute<sycl_global_item<DIM0, DIM1>, SEGMENT>
223 {
224 
225  template<typename BODY>
226  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
227  SEGMENT const& segment0,
228  SEGMENT const& segment1,
229  BODY const& body)
230  {
231  const int len1 = segment1.end() - segment1.begin();
232  const int len0 = segment0.end() - segment0.begin();
233  {
234  const int tx = ctx.itm->get_group(DIM0) * ctx.itm->get_local_range(DIM0) +
235  ctx.itm->get_local_id(DIM0);
236 
237  const int ty = ctx.itm->get_group(DIM1) * ctx.itm->get_local_range(DIM1) +
238  ctx.itm->get_local_id(DIM1);
239 
240 
241  if (tx < len0 && ty < len1)
242  body(*(segment0.begin() + tx), *(segment1.begin() + ty));
243  }
244  }
245 };
246 
253 
254 template<typename SEGMENT, int DIM0, int DIM1, int DIM2>
255 struct LoopExecute<sycl_global_item<DIM0, DIM1, DIM2>, SEGMENT>
256 {
257 
258  template<typename BODY>
259  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
260  SEGMENT const& segment0,
261  SEGMENT const& segment1,
262  SEGMENT const& segment2,
263  BODY const& body)
264  {
265  const int len2 = segment2.end() - segment2.begin();
266  const int len1 = segment1.end() - segment1.begin();
267  const int len0 = segment0.end() - segment0.begin();
268  {
269  const int tx = ctx.itm->get_group(DIM0) * ctx.itm->get_local_range(DIM0) +
270  ctx.itm->get_local_id(DIM0);
271 
272  const int ty = ctx.itm->get_group(DIM1) * ctx.itm->get_local_range(DIM1) +
273  ctx.itm->get_local_id(DIM1);
274 
275  const int tz = ctx.itm->get_group(DIM2) * ctx.itm->get_local_range(DIM2) +
276  ctx.itm->get_local_id(DIM2);
277 
278  if (tx < len0 && ty < len1 && tz < len2)
279  body(*(segment0.begin() + tx), *(segment1.begin() + ty),
280  *(segment1.begin() + ty));
281  }
282  }
283 };
284 
285 /*
286 Reshape threads in a block into a 1D iteration space
287 */
288 template<int... dim>
290 {};
291 
304 
317 
318 template<int... dim>
320 {};
321 
328 
341 
342 template<typename SEGMENT, int DIM0, int DIM1>
343 struct LoopExecute<sycl_flatten_group_local_direct<DIM0, DIM1>, SEGMENT>
344 {
345  template<typename BODY>
346  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
347  SEGMENT const& segment,
348  BODY const& body)
349  {
350 
351  const int len = segment.end() - segment.begin();
352  {
353  const int tx = ctx.itm->get_local_id(DIM0);
354  const int ty = ctx.itm->get_local_id(DIM1);
355  const int bx = ctx.itm->get_local_range(DIM0);
356  const int tid = tx + bx * ty;
357 
358  if (tid < len) body(*(segment.begin() + tid));
359  }
360  }
361 };
362 
363 template<typename SEGMENT, int DIM0, int DIM1>
364 struct LoopExecute<sycl_flatten_group_local_loop<DIM0, DIM1>, SEGMENT>
365 {
366  template<typename BODY>
367  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
368  SEGMENT const& segment,
369  BODY const& body)
370  {
371  const int len = segment.end() - segment.begin();
372 
373  const int tx = ctx.itm->get_local_id(DIM0);
374  const int ty = ctx.itm->get_local_id(DIM1);
375 
376  const int bx = ctx.itm->get_local_range(DIM0);
377  const int by = ctx.itm->get_local_range(DIM1);
378 
379  for (int tid = tx + bx * ty; tid < len; tid += bx * by)
380  {
381  body(*(segment.begin() + tid));
382  }
383  }
384 };
385 
386 template<typename SEGMENT, int DIM0, int DIM1, int DIM2>
387 struct LoopExecute<sycl_flatten_group_local_direct<DIM0, DIM1, DIM2>, SEGMENT>
388 {
389  template<typename BODY>
390  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
391  SEGMENT const& segment,
392  BODY const& body)
393  {
394  const int len = segment.end() - segment.begin();
395  {
396  const int tx = ctx.itm->get_local_id(DIM0);
397  const int ty = ctx.itm->get_local_id(DIM1);
398  const int tz = ctx.itm->get_local_id(DIM2);
399  const int bx = ctx.itm->get_local_range(DIM0);
400  const int by = ctx.itm->get_local_range(DIM1);
401 
402  const int tid = tx + bx * (ty + by * tz);
403 
404  if (tid < len) body(*(segment.begin() + tid));
405  }
406  }
407 };
408 
409 template<typename SEGMENT, int DIM0, int DIM1, int DIM2>
410 struct LoopExecute<sycl_flatten_group_local_loop<DIM0, DIM1, DIM2>, SEGMENT>
411 {
412  template<typename BODY>
413  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
414  SEGMENT const& segment,
415  BODY const& body)
416  {
417  const int len = segment.end() - segment.begin();
418 
419  const int tx = ctx.itm->get_local_id(DIM0);
420  const int ty = ctx.itm->get_local_id(DIM1);
421  const int tz = ctx.itm->get_local_id(DIM2);
422  const int bx = ctx.itm->get_local_range(DIM0);
423  const int by = ctx.itm->get_local_range(DIM1);
424  const int bz = ctx.itm->get_local_range(DIM2);
425 
426  for (int tid = tx + bx * (ty + by * tz); tid < len; tid += bx * by * bz)
427  {
428  body(*(segment.begin() + tid));
429  }
430  }
431 };
432 
433 /*
434  SYCL thread loops with block strides
435 */
436 template<typename SEGMENT, int DIM>
437 struct LoopExecute<sycl_local_012_loop<DIM>, SEGMENT>
438 {
439 
440  template<typename BODY>
441  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
442  SEGMENT const& segment,
443  BODY const& body)
444  {
445 
446  const int len = segment.end() - segment.begin();
447 
448  for (int tx = ctx.itm->get_local_id(DIM); tx < len;
449  tx += ctx.itm->get_local_range(DIM))
450  {
451  body(*(segment.begin() + tx));
452  }
453  }
454 };
455 
456 /*
457  SYCL thread direct mappings
458 */
459 template<typename SEGMENT, int DIM>
460 struct LoopExecute<sycl_local_012_direct<DIM>, SEGMENT>
461 {
462 
463  template<typename BODY>
464  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
465  SEGMENT const& segment,
466  BODY const& body)
467  {
468 
469  const int len = segment.end() - segment.begin();
470  {
471  const int tx = ctx.itm->get_local_id(DIM);
472  if (tx < len) body(*(segment.begin() + tx));
473  }
474  }
475 };
476 
477 /*
478  SYCL block loops with grid strides
479 */
480 template<typename SEGMENT, int DIM>
481 struct LoopExecute<sycl_group_012_loop<DIM>, SEGMENT>
482 {
483 
484  template<typename BODY>
485  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
486  SEGMENT const& segment,
487  BODY const& body)
488  {
489 
490  const int len = segment.end() - segment.begin();
491 
492  for (int bx = ctx.itm->get_group(DIM); bx < len;
493  bx += ctx.itm->get_group_range(DIM))
494  {
495  body(*(segment.begin() + bx));
496  }
497  }
498 };
499 
500 /*
501  SYCL block direct mappings
502 */
503 template<typename SEGMENT, int DIM>
504 struct LoopExecute<sycl_group_012_direct<DIM>, SEGMENT>
505 {
506 
507  template<typename BODY>
508  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
509  SEGMENT const& segment,
510  BODY const& body)
511  {
512 
513  const int len = segment.end() - segment.begin();
514  {
515  const int bx = ctx.itm->get_group(DIM);
516  if (bx < len) body(*(segment.begin() + bx));
517  }
518  }
519 };
520 
521 /*
522  SYCL thread loops with block strides + Return Index
523 */
524 template<typename SEGMENT, int DIM>
525 struct LoopICountExecute<sycl_local_012_loop<DIM>, SEGMENT>
526 {
527 
528  template<typename BODY>
529  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
530  SEGMENT const& segment,
531  BODY const& body)
532  {
533 
534  const int len = segment.end() - segment.begin();
535 
536  for (int tx = ctx.itm->get_local_id(DIM); tx < len;
537  tx += ctx.itm->get_local_range(DIM))
538  {
539  body(*(segment.begin() + tx), tx);
540  }
541  }
542 };
543 
544 /*
545  SYCL thread direct mappings
546 */
547 template<typename SEGMENT, int DIM>
548 struct LoopICountExecute<sycl_local_012_direct<DIM>, SEGMENT>
549 {
550 
551  template<typename BODY>
552  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
553  SEGMENT const& segment,
554  BODY const& body)
555  {
556 
557  const int len = segment.end() - segment.begin();
558  {
559  const int tx = ctx.itm->get_local_id(DIM);
560  if (tx < len) body(*(segment.begin() + tx), tx);
561  }
562  }
563 };
564 
565 /*
566  SYCL block loops with grid strides
567 */
568 template<typename SEGMENT, int DIM>
569 struct LoopICountExecute<sycl_group_012_loop<DIM>, SEGMENT>
570 {
571 
572  template<typename BODY>
573  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
574  SEGMENT const& segment,
575  BODY const& body)
576  {
577 
578  const int len = segment.end() - segment.begin();
579 
580  for (int bx = ctx.itm->get_group(DIM); bx < len;
581  bx += ctx.itm->get_group_range(DIM))
582  {
583  body(*(segment.begin() + bx), bx);
584  }
585  }
586 };
587 
588 /*
589  SYCL block direct mappings
590 */
591 template<typename SEGMENT, int DIM>
592 struct LoopICountExecute<sycl_group_012_direct<DIM>, SEGMENT>
593 {
594 
595  template<typename BODY>
596  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
597  SEGMENT const& segment,
598  BODY const& body)
599  {
600 
601  const int len = segment.end() - segment.begin();
602  {
603  const int bx = ctx.itm->get_group(DIM);
604  if (bx < len) body(*(segment.begin() + bx), bx);
605  }
606  }
607 };
608 
609 // perfectly nested sycl direct policies
// Each alias names sycl_group_012_direct instantiated with two or three
// dimension indices; the digits in the alias name are the template
// arguments in the same order (e.g. _10_ is <1, 0>).
610 using sycl_group_01_nested_direct = sycl_group_012_direct<0, 1>;
611 using sycl_group_02_nested_direct = sycl_group_012_direct<0, 2>;
612 using sycl_group_10_nested_direct = sycl_group_012_direct<1, 0>;
613 using sycl_group_12_nested_direct = sycl_group_012_direct<1, 2>;
614 using sycl_group_20_nested_direct = sycl_group_012_direct<2, 0>;
615 using sycl_group_21_nested_direct = sycl_group_012_direct<2, 1>;
616 
// Three-index forms: all six orderings of 0, 1, 2.
617 using sycl_group_012_nested_direct = sycl_group_012_direct<0, 1, 2>;
618 using sycl_group_021_nested_direct = sycl_group_012_direct<0, 2, 1>;
619 using sycl_group_102_nested_direct = sycl_group_012_direct<1, 0, 2>;
620 using sycl_group_120_nested_direct = sycl_group_012_direct<1, 2, 0>;
621 using sycl_group_201_nested_direct = sycl_group_012_direct<2, 0, 1>;
622 using sycl_group_210_nested_direct = sycl_group_012_direct<2, 1, 0>;
623 
624 template<typename SEGMENT, int DIM0, int DIM1>
625 struct LoopExecute<sycl_group_012_direct<DIM0, DIM1>, SEGMENT>
626 {
627 
628  template<typename BODY>
629  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
630  SEGMENT const& segment0,
631  SEGMENT const& segment1,
632  BODY const& body)
633  {
634  const int len1 = segment1.end() - segment1.begin();
635  const int len0 = segment0.end() - segment0.begin();
636  {
637  const int tx = ctx.itm->get_group(DIM0);
638  const int ty = ctx.itm->get_group(DIM1);
639  if (tx < len0 && ty < len1)
640  body(*(segment0.begin() + tx), *(segment1.begin() + ty));
641  }
642  }
643 };
644 
645 template<typename SEGMENT, int DIM0, int DIM1, int DIM2>
646 struct LoopExecute<sycl_group_012_direct<DIM0, DIM1, DIM2>, SEGMENT>
647 {
648 
649  template<typename BODY>
650  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
651  SEGMENT const& segment0,
652  SEGMENT const& segment1,
653  SEGMENT const& segment2,
654  BODY const& body)
655  {
656  const int len2 = segment2.end() - segment2.begin();
657  const int len1 = segment1.end() - segment1.begin();
658  const int len0 = segment0.end() - segment0.begin();
659  {
660  const int tx = ctx.itm->get_group(DIM0);
661  const int ty = ctx.itm->get_group(DIM1);
662  const int tz = ctx.itm->get_group(DIM2);
663  if (tx < len0 && ty < len1 && tz < len2)
664  body(*(segment0.begin() + tx), *(segment1.begin() + ty),
665  *(segment2.begin() + tz));
666  }
667  }
668 };
669 
670 /*
671  Perfectly nested sycl direct policies
672  Return local index
673 */
674 template<typename SEGMENT, int DIM0, int DIM1>
675 struct LoopICountExecute<sycl_group_012_direct<DIM0, DIM1>, SEGMENT>
676 {
677 
678  template<typename BODY>
679  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
680  SEGMENT const& segment0,
681  SEGMENT const& segment1,
682  BODY const& body)
683  {
684  const int len1 = segment1.end() - segment1.begin();
685  const int len0 = segment0.end() - segment0.begin();
686  {
687  const int tx = ctx.itm->get_group(DIM0);
688  const int ty = ctx.itm->get_group(DIM1);
689  if (tx < len0 && ty < len1)
690  body(*(segment0.begin() + tx), *(segment1.begin() + ty), tx, ty);
691  }
692  }
693 };
694 
695 template<typename SEGMENT, int DIM0, int DIM1, int DIM2>
696 struct LoopICountExecute<sycl_group_012_direct<DIM0, DIM1, DIM2>, SEGMENT>
697 {
698 
699  template<typename BODY>
700  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
701  SEGMENT const& segment0,
702  SEGMENT const& segment1,
703  SEGMENT const& segment2,
704  BODY const& body)
705  {
706  const int len2 = segment2.end() - segment2.begin();
707  const int len1 = segment1.end() - segment1.begin();
708  const int len0 = segment0.end() - segment0.begin();
709  {
710  const int tx = ctx.itm->get_group(DIM0);
711  const int ty = ctx.itm->get_group(DIM1);
712  const int tz = ctx.itm->get_group(DIM2);
713  if (tx < len0 && ty < len1 && tz < len2)
714  body(*(segment0.begin() + tx), *(segment1.begin() + ty),
715  *(segment2.begin() + tz), tx, ty, tz);
716  }
717  }
718 };
719 
720 // perfectly nested sycl loop policies
// Each alias names sycl_group_012_loop instantiated with two or three
// dimension indices; the digits in the alias name are the template
// arguments in the same order (e.g. _10_ is <1, 0>).
721 using sycl_group_01_nested_loop = sycl_group_012_loop<0, 1>;
722 using sycl_group_02_nested_loop = sycl_group_012_loop<0, 2>;
723 using sycl_group_10_nested_loop = sycl_group_012_loop<1, 0>;
724 using sycl_group_12_nested_loop = sycl_group_012_loop<1, 2>;
725 using sycl_group_20_nested_loop = sycl_group_012_loop<2, 0>;
726 using sycl_group_21_nested_loop = sycl_group_012_loop<2, 1>;
727 
// Three-index forms: all six orderings of 0, 1, 2.
728 using sycl_group_012_nested_loop = sycl_group_012_loop<0, 1, 2>;
729 using sycl_group_021_nested_loop = sycl_group_012_loop<0, 2, 1>;
730 using sycl_group_102_nested_loop = sycl_group_012_loop<1, 0, 2>;
731 using sycl_group_120_nested_loop = sycl_group_012_loop<1, 2, 0>;
732 using sycl_group_201_nested_loop = sycl_group_012_loop<2, 0, 1>;
733 using sycl_group_210_nested_loop = sycl_group_012_loop<2, 1, 0>;
734 
735 template<typename SEGMENT, int DIM0, int DIM1>
736 struct LoopExecute<sycl_group_012_loop<DIM0, DIM1>, SEGMENT>
737 {
738 
739  template<typename BODY>
740  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
741  SEGMENT const& segment0,
742  SEGMENT const& segment1,
743  BODY const& body)
744  {
745  const int len1 = segment1.end() - segment1.begin();
746  const int len0 = segment0.end() - segment0.begin();
747  {
748 
749  for (int bx = ctx.itm->get_group(DIM0); bx < len0;
750  bx += ctx.itm->get_group_range(DIM0))
751  {
752  for (int by = ctx.itm->get_group(DIM1); by < len1;
753  bx += ctx.itm->get_group_range(DIM1))
754  {
755  body(*(segment0.begin() + bx), *(segment1.begin() + by));
756  }
757  }
758  }
759  }
760 };
761 
762 template<typename SEGMENT, int DIM0, int DIM1, int DIM2>
763 struct LoopExecute<sycl_group_012_loop<DIM0, DIM1, DIM2>, SEGMENT>
764 {
765 
766  template<typename BODY>
767  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
768  SEGMENT const& segment0,
769  SEGMENT const& segment1,
770  SEGMENT const& segment2,
771  BODY const& body)
772  {
773  const int len2 = segment2.end() - segment2.begin();
774  const int len1 = segment1.end() - segment1.begin();
775  const int len0 = segment0.end() - segment0.begin();
776 
777  for (int bx = ctx.itm->get_group(DIM0); bx < len0;
778  bx += ctx.itm->get_group_range(DIM0))
779  {
780 
781  for (int by = ctx.itm->get_group(DIM1); by < len1;
782  by += ctx.itm->get_group_range(DIM1))
783  {
784 
785  for (int bz = ctx.itm->get_group(DIM2); bz < len2;
786  bz += ctx.itm->get_group_range(DIM2))
787  {
788 
789  body(*(segment0.begin() + bx), *(segment1.begin() + by),
790  *(segment2.begin() + bz));
791  }
792  }
793  }
794  }
795 };
796 
797 /*
798  perfectly nested sycl loop policies + returns local index
799 */
800 template<typename SEGMENT, int DIM0, int DIM1>
801 struct LoopICountExecute<sycl_group_012_loop<DIM0, DIM1>, SEGMENT>
802 {
803 
804  template<typename BODY>
805  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
806  SEGMENT const& segment0,
807  SEGMENT const& segment1,
808  BODY const& body)
809  {
810  const int len1 = segment1.end() - segment1.begin();
811  const int len0 = segment0.end() - segment0.begin();
812  {
813 
814  for (int bx = ctx.itm->get_group(DIM0); bx < len0;
815  bx += ctx.itm->get_group_range(DIM0))
816  {
817  for (int by = ctx.itm->get_group(DIM0); by < len1;
818  by += ctx.itm->get_group_range(DIM1))
819  {
820 
821  body(*(segment0.begin() + bx), *(segment1.begin() + by), bx, by);
822  }
823  }
824  }
825  }
826 };
827 
828 template<typename SEGMENT, int DIM0, int DIM1, int DIM2>
829 struct LoopICountExecute<sycl_group_012_loop<DIM0, DIM1, DIM2>, SEGMENT>
830 {
831 
832  template<typename BODY>
833  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
834  SEGMENT const& segment0,
835  SEGMENT const& segment1,
836  SEGMENT const& segment2,
837  BODY const& body)
838  {
839  const int len2 = segment2.end() - segment2.begin();
840  const int len1 = segment1.end() - segment1.begin();
841  const int len0 = segment0.end() - segment0.begin();
842 
843  for (int bx = ctx.itm->get_group(DIM0); bx < len0;
844  bx += ctx.itm->get_group_range(DIM0))
845  {
846 
847  for (int by = ctx.itm->get_group(DIM0); by < len1;
848  by += ctx.itm->get_group_range(DIM0))
849  {
850 
851  for (int bz = ctx.itm->get_group(DIM0); bz < len2;
852  bz += ctx.itm->get_group_range(DIM0))
853  {
854 
855  body(*(segment0.begin() + bx), *(segment1.begin() + by),
856  *(segment2.begin() + bz), bx, by, bz);
857  }
858  }
859  }
860  }
861 };
862 
863 template<typename SEGMENT, int DIM>
864 struct TileExecute<sycl_local_012_loop<DIM>, SEGMENT>
865 {
866 
867  template<typename TILE_T, typename BODY>
868  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
869  TILE_T tile_size,
870  SEGMENT const& segment,
871  BODY const& body)
872  {
873 
874  const int len = segment.end() - segment.begin();
875 
876  for (int tx = ctx.itm->get_local_id(DIM) * tile_size; tx < len;
877  tx += ctx.itm->get_local_range(DIM) * tile_size)
878  {
879  body(segment.slice(tx, tile_size));
880  }
881  }
882 };
883 
884 template<typename SEGMENT, int DIM>
885 struct TileExecute<sycl_local_012_direct<DIM>, SEGMENT>
886 {
887 
888  template<typename TILE_T, typename BODY>
889  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
890  TILE_T tile_size,
891  SEGMENT const& segment,
892  BODY const& body)
893  {
894 
895  const int len = segment.end() - segment.begin();
896 
897  int tx = ctx.itm->get_local_id(DIM) * tile_size;
898  if (tx < len)
899  {
900  body(segment.slice(tx, tile_size));
901  }
902  }
903 };
904 
905 template<typename SEGMENT, int DIM>
906 struct TileExecute<sycl_group_012_loop<DIM>, SEGMENT>
907 {
908 
909  template<typename TILE_T, typename BODY>
910  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
911  TILE_T tile_size,
912  SEGMENT const& segment,
913  BODY const& body)
914  {
915 
916  const int len = segment.end() - segment.begin();
917 
918  for (int tx = ctx.itm->get_group(DIM) * tile_size;
919 
920  tx < len;
921 
922  tx += ctx.itm->get_group_range(DIM) * tile_size)
923  {
924  body(segment.slice(tx, tile_size));
925  }
926  }
927 };
928 
929 template<typename SEGMENT, int DIM>
930 struct TileExecute<sycl_group_012_direct<DIM>, SEGMENT>
931 {
932 
933  template<typename TILE_T, typename BODY>
934  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
935  TILE_T tile_size,
936  SEGMENT const& segment,
937  BODY const& body)
938  {
939 
940  const int len = segment.end() - segment.begin();
941 
942  int tx = ctx.itm->get_group(DIM) * tile_size;
943  if (tx < len)
944  {
945  body(segment.slice(tx, tile_size));
946  }
947  }
948 };
949 
950 // Tile execute + return index
951 template<typename SEGMENT, int DIM>
952 struct TileTCountExecute<sycl_local_012_loop<DIM>, SEGMENT>
953 {
954 
955  template<typename TILE_T, typename BODY>
956  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
957  TILE_T tile_size,
958  SEGMENT const& segment,
959  BODY const& body)
960  {
961 
962  const int len = segment.end() - segment.begin();
963 
964  for (int tx = ctx.itm->get_local_id(DIM) * tile_size; tx < len;
965  tx += ctx.itm->get_local_range(DIM) * tile_size)
966  {
967  body(segment.slice(tx, tile_size), tx / tile_size);
968  }
969  }
970 };
971 
972 template<typename SEGMENT, int DIM>
973 struct TileTCountExecute<sycl_local_012_direct<DIM>, SEGMENT>
974 {
975 
976  template<typename TILE_T, typename BODY>
977  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
978  TILE_T tile_size,
979  SEGMENT const& segment,
980  BODY const& body)
981  {
982 
983  const int len = segment.end() - segment.begin();
984 
985  int tx = ctx.itm->get_local_id(DIM) * tile_size;
986  if (tx < len)
987  {
988  body(segment.slice(tx, tile_size), tx / tile_size);
989  }
990  }
991 };
992 
993 template<typename SEGMENT, int DIM>
994 struct TileTCountExecute<sycl_group_012_loop<DIM>, SEGMENT>
995 {
996 
997  template<typename TILE_T, typename BODY>
998  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
999  TILE_T tile_size,
1000  SEGMENT const& segment,
1001  BODY const& body)
1002  {
1003 
1004  const int len = segment.end() - segment.begin();
1005 
1006  for (int bx = ctx.itm->get_group(DIM) * tile_size; bx < len;
1007  bx += ctx.itm->get_group_range(DIM) * tile_size)
1008  {
1009  body(segment.slice(bx, tile_size), bx / tile_size);
1010  }
1011  }
1012 };
1013 
1014 template<typename SEGMENT, int DIM>
1015 struct TileTCountExecute<sycl_group_012_direct<DIM>, SEGMENT>
1016 {
1017 
1018  template<typename TILE_T, typename BODY>
1019  static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const& ctx,
1020  TILE_T tile_size,
1021  SEGMENT const& segment,
1022  BODY const& body)
1023  {
1024 
1025  const int len = segment.end() - segment.begin();
1026 
1027  int bx = ctx.itm->get_group(DIM) * tile_size;
1028  if (bx < len)
1029  {
1030  body(segment.slice(bx, tile_size), bx / tile_size);
1031  }
1032  }
1033 };
1034 
1035 } // namespace RAJA
1036 #endif
Header file defining prototypes for routines used to manage memory for SYCL reductions and other oper...
RAJA header file containing the core components of RAJA::launch.
#define RAJA_DEVICE
Definition: macros.hpp:66
constexpr RAJA_HOST_DEVICE auto invoke_body(Params &&params, Fn &&f, Ts &&... extra)
Definition: forall.hpp:598
Definition: AlignedRangeIndexSetBuilders.cpp:35
LaunchContextType ctx
Definition: launch.hpp:185
sycl_group_012_loop< 1, 0, 2 > sycl_group_102_nested_loop
Definition: launch.hpp:730
sycl_group_012_loop< 2, 0, 1 > sycl_group_201_nested_loop
Definition: launch.hpp:732
sycl_group_012_loop< 0, 1 > sycl_group_01_nested_loop
Definition: launch.hpp:721
sycl_group_012_direct< 2, 1 > sycl_group_21_nested_direct
Definition: launch.hpp:615
sycl_group_012_direct< 1, 2, 0 > sycl_group_120_nested_direct
Definition: launch.hpp:620
sycl_group_012_loop< 1, 2 > sycl_group_12_nested_loop
Definition: launch.hpp:724
sycl_group_012_loop< 0, 2, 1 > sycl_group_021_nested_loop
Definition: launch.hpp:729
sycl_group_012_direct< 2, 0 > sycl_group_20_nested_direct
Definition: launch.hpp:614
sycl_group_012_loop< 2, 0 > sycl_group_20_nested_loop
Definition: launch.hpp:725
sycl_group_012_loop< 2, 1, 0 > sycl_group_210_nested_loop
Definition: launch.hpp:733
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
sycl_group_012_loop< 1, 2, 0 > sycl_group_120_nested_loop
Definition: launch.hpp:731
sycl_group_012_loop< 1, 0 > sycl_group_10_nested_loop
Definition: launch.hpp:723
auto & body
Definition: launch.hpp:177
sycl_group_012_direct< 0, 2, 1 > sycl_group_021_nested_direct
Definition: launch.hpp:618
sycl_group_012_direct< 2, 0, 1 > sycl_group_201_nested_direct
Definition: launch.hpp:621
sycl_group_012_loop< 0, 1, 2 > sycl_group_012_nested_loop
Definition: launch.hpp:728
sycl_group_012_loop< 0, 2 > sycl_group_02_nested_loop
Definition: launch.hpp:722
sycl_group_012_direct< 0, 1 > sycl_group_01_nested_direct
Definition: launch.hpp:610
sycl_group_012_direct< 1, 2 > sycl_group_12_nested_direct
Definition: launch.hpp:613
sycl_group_012_loop< 2, 1 > sycl_group_21_nested_loop
Definition: launch.hpp:726
sycl_group_012_direct< 0, 1, 2 > sycl_group_012_nested_direct
Definition: launch.hpp:617
sycl_group_012_direct< 1, 0 > sycl_group_10_nested_direct
Definition: launch.hpp:612
sycl_group_012_direct< 2, 1, 0 > sycl_group_210_nested_direct
Definition: launch.hpp:622
sycl_group_012_direct< 1, 0, 2 > sycl_group_102_nested_direct
Definition: launch.hpp:619
sycl_group_012_direct< 0, 2 > sycl_group_02_nested_direct
Definition: launch.hpp:611
Header file for RAJA resource definitions.
static concepts::enable_if_t< resources::EventProxy< resources::Resource >, RAJA::expt::type_traits::is_ForallParamPack< ReduceParams > > exec(RAJA::resources::Resource res, const LaunchParams &launch_params, LoopBody &&loop_body, ReduceParams launch_reducers)
Definition: launch.hpp:40
Definition: launch_core.hpp:263
Definition: launch_core.hpp:163
size_t shared_mem_size
Definition: launch_core.hpp:167
Teams teams
Definition: launch_core.hpp:165
Threads threads
Definition: launch_core.hpp:166
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:390
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:346
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:413
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:367
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment0, SEGMENT const &segment1, SEGMENT const &segment2, BODY const &body)
Definition: launch.hpp:259
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment0, SEGMENT const &segment1, BODY const &body)
Definition: launch.hpp:226
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:199
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment0, SEGMENT const &segment1, SEGMENT const &segment2, BODY const &body)
Definition: launch.hpp:650
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment0, SEGMENT const &segment1, BODY const &body)
Definition: launch.hpp:629
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:508
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment0, SEGMENT const &segment1, SEGMENT const &segment2, BODY const &body)
Definition: launch.hpp:767
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment0, SEGMENT const &segment1, BODY const &body)
Definition: launch.hpp:740
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:485
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:464
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:441
Definition: launch_core.hpp:480
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment0, SEGMENT const &segment1, SEGMENT const &segment2, BODY const &body)
Definition: launch.hpp:700
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment0, SEGMENT const &segment1, BODY const &body)
Definition: launch.hpp:679
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:596
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment0, SEGMENT const &segment1, SEGMENT const &segment2, BODY const &body)
Definition: launch.hpp:833
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment0, SEGMENT const &segment1, BODY const &body)
Definition: launch.hpp:805
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:573
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:552
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:529
Definition: launch_core.hpp:483
int value[3]
Definition: launch_core.hpp:99
int value[3]
Definition: launch_core.hpp:124
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:934
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:910
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:889
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:868
Definition: launch_core.hpp:579
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:1019
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:998
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:977
static RAJA_INLINE RAJA_DEVICE void exec(LaunchContext const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch.hpp:956
Definition: launch_core.hpp:582
static constexpr void parampack_resolve(EXEC_POL const &pol, ForallParamPack< Params... > &f_params, Args &&... args)
Definition: forall.hpp:304
static constexpr void parampack_init(EXEC_POL const &pol, ForallParamPack< Params... > &f_params, Args &&... args)
Definition: forall.hpp:269
static RAJA_HOST_DEVICE constexpr void parampack_combine(EXEC_POL const &pol, ForallParamPack< Params... > &f_params, Args &&... args)
Definition: forall.hpp:286
Definition: TypeTraits.hpp:59
Definition: launch.hpp:290
Definition: launch.hpp:320
Definition: launch.hpp:188
Header file containing RAJA SYCL policy definitions.