RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
scan.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_scan_hip_HPP
21 #define RAJA_scan_hip_HPP
22 
23 #include "RAJA/config.hpp"
24 
25 #if defined(RAJA_ENABLE_HIP)
26 
27 #include <iterator>
28 #include <type_traits>
29 
30 #if defined(__HIPCC__)
31 // Tell rocprim to provide its HIP API
32 #define ROCPRIM_HIP_API 1
33 #include "rocprim/device/device_scan.hpp"
34 #elif defined(__CUDACC__)
35 #include "cub/device/device_scan.cuh"
36 #include "cub/util_allocator.cuh"
37 #endif
38 
41 
42 namespace RAJA
43 {
44 namespace impl
45 {
46 namespace scan
47 {
48 
53 template<typename IterationMapping,
54  typename IterationGetter,
55  typename Concretizer,
56  bool Async,
57  typename InputIter,
58  typename Function>
59 RAJA_INLINE resources::EventProxy<resources::Hip> inclusive_inplace(
60  resources::Hip hip_res,
61  ::RAJA::policy::hip::
62  hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
63  InputIter begin,
64  InputIter end,
65  Function binary_op)
66 {
67  hipStream_t stream = hip_res.get_stream();
68 
69  int len = std::distance(begin, end);
70  // Determine temporary device storage requirements
71  void* d_temp_storage = nullptr;
72  size_t temp_storage_bytes = 0;
73 #if defined(__HIPCC__)
74  CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::inclusive_scan, d_temp_storage,
75  temp_storage_bytes, begin, begin, len,
76  binary_op, stream);
77 #elif defined(__CUDACC__)
78  CAMP_HIP_API_INVOKE_AND_CHECK(::cub::DeviceScan::InclusiveScan,
79  d_temp_storage, temp_storage_bytes, begin,
80  begin, binary_op, len, stream);
81 #endif
82 
83  // Allocate temporary storage
84  d_temp_storage =
85  hip::device_mempool_type::getInstance().malloc<unsigned char>(
86  temp_storage_bytes);
87  // Run
88 #if defined(__HIPCC__)
89  CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::inclusive_scan, d_temp_storage,
90  temp_storage_bytes, begin, begin, len,
91  binary_op, stream);
92 #elif defined(__CUDACC__)
93  CAMP_HIP_API_INVOKE_AND_CHECK(::cub::DeviceScan::InclusiveScan,
94  d_temp_storage, temp_storage_bytes, begin,
95  begin, binary_op, len, stream);
96 #endif
97  // Free temporary storage
98  hip::device_mempool_type::getInstance().free(d_temp_storage);
99 
100  hip::launch(hip_res, Async);
101 
102  return resources::EventProxy<resources::Hip>(hip_res);
103 }
104 
109 template<typename IterationMapping,
110  typename IterationGetter,
111  typename Concretizer,
112  bool Async,
113  typename InputIter,
114  typename Function,
115  typename T>
116 RAJA_INLINE resources::EventProxy<resources::Hip> exclusive_inplace(
117  resources::Hip hip_res,
118  ::RAJA::policy::hip::
119  hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
120  InputIter begin,
121  InputIter end,
122  Function binary_op,
123  T init)
124 {
125  hipStream_t stream = hip_res.get_stream();
126 
127  int len = std::distance(begin, end);
128  // Determine temporary device storage requirements
129  void* d_temp_storage = nullptr;
130  size_t temp_storage_bytes = 0;
131 #if defined(__HIPCC__)
132  CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::exclusive_scan, d_temp_storage,
133  temp_storage_bytes, begin, begin, init, len,
134  binary_op, stream);
135 #elif defined(__CUDACC__)
136  CAMP_HIP_API_INVOKE_AND_CHECK(::cub::DeviceScan::ExclusiveScan,
137  d_temp_storage, temp_storage_bytes, begin,
138  begin, binary_op, init, len, stream);
139 #endif
140  // Allocate temporary storage
141  d_temp_storage =
142  hip::device_mempool_type::getInstance().malloc<unsigned char>(
143  temp_storage_bytes);
144  // Run
145 #if defined(__HIPCC__)
146  CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::exclusive_scan, d_temp_storage,
147  temp_storage_bytes, begin, begin, init, len,
148  binary_op, stream);
149 #elif defined(__CUDACC__)
150  CAMP_HIP_API_INVOKE_AND_CHECK(::cub::DeviceScan::ExclusiveScan,
151  d_temp_storage, temp_storage_bytes, begin,
152  begin, binary_op, init, len, stream);
153 #endif
154  // Free temporary storage
155  hip::device_mempool_type::getInstance().free(d_temp_storage);
156 
157  hip::launch(hip_res, Async);
158 
159  return resources::EventProxy<resources::Hip>(hip_res);
160 }
161 
166 template<typename IterationMapping,
167  typename IterationGetter,
168  typename Concretizer,
169  bool Async,
170  typename InputIter,
171  typename OutputIter,
172  typename Function>
173 RAJA_INLINE resources::EventProxy<resources::Hip> inclusive(
174  resources::Hip hip_res,
175  ::RAJA::policy::hip::
176  hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
177  InputIter begin,
178  InputIter end,
179  OutputIter out,
180  Function binary_op)
181 {
182  hipStream_t stream = hip_res.get_stream();
183 
184  int len = std::distance(begin, end);
185  // Determine temporary device storage requirements
186  void* d_temp_storage = nullptr;
187  size_t temp_storage_bytes = 0;
188 #if defined(__HIPCC__)
189  CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::inclusive_scan, d_temp_storage,
190  temp_storage_bytes, begin, out, len, binary_op,
191  stream);
192 #elif defined(__CUDACC__)
193  CAMP_HIP_API_INVOKE_AND_CHECK(::cub::DeviceScan::InclusiveScan,
194  d_temp_storage, temp_storage_bytes, begin, out,
195  binary_op, len, stream);
196 #endif
197  // Allocate temporary storage
198  d_temp_storage =
199  hip::device_mempool_type::getInstance().malloc<unsigned char>(
200  temp_storage_bytes);
201  // Run
202 #if defined(__HIPCC__)
203  CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::inclusive_scan, d_temp_storage,
204  temp_storage_bytes, begin, out, len, binary_op,
205  stream);
206 #elif defined(__CUDACC__)
207  CAMP_HIP_API_INVOKE_AND_CHECK(::cub::DeviceScan::InclusiveScan,
208  d_temp_storage, temp_storage_bytes, begin, out,
209  binary_op, len, stream);
210 #endif
211  // Free temporary storage
212  hip::device_mempool_type::getInstance().free(d_temp_storage);
213 
214  hip::launch(hip_res, Async);
215 
216  return resources::EventProxy<resources::Hip>(hip_res);
217 }
218 
223 template<typename IterationMapping,
224  typename IterationGetter,
225  typename Concretizer,
226  bool Async,
227  typename InputIter,
228  typename OutputIter,
229  typename Function,
230  typename T>
231 RAJA_INLINE resources::EventProxy<resources::Hip> exclusive(
232  resources::Hip hip_res,
233  ::RAJA::policy::hip::
234  hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
235  InputIter begin,
236  InputIter end,
237  OutputIter out,
238  Function binary_op,
239  T init)
240 {
241  hipStream_t stream = hip_res.get_stream();
242 
243  int len = std::distance(begin, end);
244  // Determine temporary device storage requirements
245  void* d_temp_storage = nullptr;
246  size_t temp_storage_bytes = 0;
247 #if defined(__HIPCC__)
248  CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::exclusive_scan, d_temp_storage,
249  temp_storage_bytes, begin, out, init, len,
250  binary_op, stream);
251 #elif defined(__CUDACC__)
252  CAMP_HIP_API_INVOKE_AND_CHECK(::cub::DeviceScan::ExclusiveScan,
253  d_temp_storage, temp_storage_bytes, begin, out,
254  binary_op, init, len, stream);
255 #endif
256  // Allocate temporary storage
257  d_temp_storage =
258  hip::device_mempool_type::getInstance().malloc<unsigned char>(
259  temp_storage_bytes);
260  // Run
261 #if defined(__HIPCC__)
262  CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::exclusive_scan, d_temp_storage,
263  temp_storage_bytes, begin, out, init, len,
264  binary_op, stream);
265 #elif defined(__CUDACC__)
266  CAMP_HIP_API_INVOKE_AND_CHECK(::cub::DeviceScan::ExclusiveScan,
267  d_temp_storage, temp_storage_bytes, begin, out,
268  binary_op, init, len, stream);
269 #endif
270  // Free temporary storage
271  hip::device_mempool_type::getInstance().free(d_temp_storage);
272 
273  hip::launch(hip_res, Async);
274 
275  return resources::EventProxy<resources::Hip>(hip_res);
276 }
277 
278 } // namespace scan
279 
280 } // namespace impl
281 
282 } // namespace RAJA
283 
284 #endif // closing endif for RAJA_ENABLE_HIP guard
285 
286 #endif // closing endif for header file include guard
Header file defining prototypes for routines used to manage memory for HIP reductions and other operations.
Header file containing RAJA HIP policy definitions.
RAJA_INLINE concepts::enable_if_t< resources::EventProxy< resources::Host >, type_traits::is_openmp_policy< Policy > > inclusive(resources::Host host_res, const Policy &exec, Iter begin, Iter end, OutIter out, BinFn f)
Definition: scan.hpp:144
RAJA_INLINE concepts::enable_if_t< resources::EventProxy< resources::Host >, type_traits::is_openmp_policy< Policy > > exclusive(resources::Host host_res, const Policy &exec, Iter begin, Iter end, OutIter out, BinFn f, ValueT v)
Definition: scan.hpp:167
RAJA_INLINE concepts::enable_if_t< resources::EventProxy< resources::Host >, type_traits::is_openmp_policy< Policy > > inclusive_inplace(resources::Host host_res, const Policy &, Iter begin, Iter end, BinFn f)
Definition: scan.hpp:51
RAJA_INLINE concepts::enable_if_t< resources::EventProxy< resources::Host >, type_traits::is_openmp_policy< Policy > > exclusive_inplace(resources::Host host_res, const Policy &, Iter begin, Iter end, BinFn f, ValueT v)
Definition: scan.hpp:96
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_INLINE concepts::enable_if_t< resources::EventProxy< Res >, type_traits::is_execution_policy< ExecPolicy >, type_traits::is_resource< Res > > inclusive_scan(Res r, Args &&... args)
Definition: scan.hpp:381
void launch(LaunchParams const &launch_params, ReduceParams &&... rest_of_launch_args)
Definition: launch_core.hpp:268
RAJA_INLINE concepts::enable_if_t< resources::EventProxy< Res >, type_traits::is_execution_policy< ExecPolicy >, type_traits::is_resource< Res > > exclusive_scan(Res r, Args &&... args)
Definition: scan.hpp:352