20 #ifndef RAJA_sort_hip_HPP
21 #define RAJA_sort_hip_HPP
23 #include "RAJA/config.hpp"
25 #if defined(RAJA_ENABLE_HIP)
29 #include <type_traits>
31 #if defined(__HIPCC__)
33 #define ROCPRIM_HIP_API 1
34 #include "rocprim/device/device_transform.hpp"
35 #include "rocprim/device/device_radix_sort.hpp"
36 #elif defined(__CUDACC__)
37 #include "cub/device/device_radix_sort.cuh"
56 #if defined(__HIPCC__)
58 using double_buffer = ::rocprim::double_buffer<R>;
59 #elif defined(__CUDACC__)
61 using double_buffer = ::cub::DoubleBuffer<R>;
65 R* get_current(double_buffer<R>& d_bufs)
67 #if defined(__HIPCC__)
68 return d_bufs.current();
69 #elif defined(__CUDACC__)
70 return d_bufs.Current();
79 template<
typename IterationMapping,
80 typename IterationGetter,
85 concepts::enable_if_t<
86 resources::EventProxy<resources::Hip>,
87 concepts::negate<concepts::all_of<
88 type_traits::is_arithmetic<RAJA::detail::IterVal<Iter>>,
89 std::is_pointer<Iter>,
91 camp::is_same<Compare,
92 operators::less<RAJA::detail::IterVal<Iter>>>,
93 camp::is_same<Compare,
94 operators::greater<RAJA::detail::IterVal<Iter>>>>>>>
95 stable(resources::Hip hip_res,
97 hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
105 std::is_pointer<Iter>,
107 camp::is_same<Compare,
109 camp::is_same<Compare, operators::greater<
111 "RAJA stable_sort<hip_exec> is only implemented for pointers to "
112 "arithmetic types and RAJA::operators::less and "
113 "RAJA::operators::greater.");
115 return resources::EventProxy<resources::Hip>(hip_res);
121 template<
typename IterationMapping,
122 typename IterationGetter,
123 typename Concretizer,
126 concepts::enable_if_t<resources::EventProxy<resources::Hip>,
127 type_traits::is_arithmetic<RAJA::detail::IterVal<Iter>>,
128 std::is_pointer<Iter>>
129 stable(resources::Hip hip_res,
130 ::RAJA::policy::hip::
131 hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
136 hipStream_t stream = hip_res.get_stream();
140 int len = std::distance(begin, end);
142 int end_bit =
sizeof(R) * CHAR_BIT;
145 R* d_out = hip::device_mempool_type::getInstance().malloc<R>(len);
149 detail::double_buffer<R> d_keys(begin, d_out);
152 void* d_temp_storage =
nullptr;
153 size_t temp_storage_bytes = 0;
154 #if defined(__HIPCC__)
155 CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::radix_sort_keys, d_temp_storage,
156 temp_storage_bytes, d_keys, len, begin_bit,
158 #elif defined(__CUDACC__)
159 CAMP_CUDA_API_INVOKE_AND_CHECK(::cub::DeviceRadixSort::SortKeys,
160 d_temp_storage, temp_storage_bytes, d_keys,
161 len, begin_bit, end_bit, stream);
165 hip::device_mempool_type::getInstance().malloc<
unsigned char>(
169 #if defined(__HIPCC__)
170 CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::radix_sort_keys, d_temp_storage,
171 temp_storage_bytes, d_keys, len, begin_bit,
173 #elif defined(__CUDACC__)
174 CAMP_CUDA_API_INVOKE_AND_CHECK(::cub::DeviceRadixSort::SortKeys,
175 d_temp_storage, temp_storage_bytes, d_keys,
176 len, begin_bit, end_bit, stream);
179 hip::device_mempool_type::getInstance().free(d_temp_storage);
181 if (detail::get_current(d_keys) == d_out)
185 CAMP_HIP_API_INVOKE_AND_CHECK(hipMemcpyAsync, begin, d_out, len *
sizeof(R),
186 hipMemcpyDefault, stream);
189 hip::device_mempool_type::getInstance().free(d_out);
193 return resources::EventProxy<resources::Hip>(hip_res);
199 template<
typename IterationMapping,
200 typename IterationGetter,
201 typename Concretizer,
204 concepts::enable_if_t<resources::EventProxy<resources::Hip>,
205 type_traits::is_arithmetic<RAJA::detail::IterVal<Iter>>,
206 std::is_pointer<Iter>>
207 stable(resources::Hip hip_res,
208 ::RAJA::policy::hip::
209 hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
214 hipStream_t stream = hip_res.get_stream();
218 int len = std::distance(begin, end);
220 int end_bit =
sizeof(R) * CHAR_BIT;
223 R* d_out = hip::device_mempool_type::getInstance().malloc<R>(len);
227 detail::double_buffer<R> d_keys(begin, d_out);
230 void* d_temp_storage =
nullptr;
231 size_t temp_storage_bytes = 0;
232 #if defined(__HIPCC__)
233 CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::radix_sort_keys_desc, d_temp_storage,
234 temp_storage_bytes, d_keys, len, begin_bit,
236 #elif defined(__CUDACC__)
237 CAMP_CUDA_API_INVOKE_AND_CHECK(::cub::DeviceRadixSort::SortKeysDescending,
238 d_temp_storage, temp_storage_bytes, d_keys,
239 len, begin_bit, end_bit, stream);
243 hip::device_mempool_type::getInstance().malloc<
unsigned char>(
247 #if defined(__HIPCC__)
248 CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::radix_sort_keys_desc, d_temp_storage,
249 temp_storage_bytes, d_keys, len, begin_bit,
251 #elif defined(__CUDACC__)
252 CAMP_CUDA_API_INVOKE_AND_CHECK(::cub::DeviceRadixSort::SortKeysDescending,
253 d_temp_storage, temp_storage_bytes, d_keys,
254 len, begin_bit, end_bit, stream);
257 hip::device_mempool_type::getInstance().free(d_temp_storage);
259 if (detail::get_current(d_keys) == d_out)
263 CAMP_HIP_API_INVOKE_AND_CHECK(hipMemcpyAsync, begin, d_out, len *
sizeof(R),
264 hipMemcpyDefault, stream);
267 hip::device_mempool_type::getInstance().free(d_out);
271 return resources::EventProxy<resources::Hip>(hip_res);
277 template<
typename IterationMapping,
278 typename IterationGetter,
279 typename Concretizer,
283 concepts::enable_if_t<
284 resources::EventProxy<resources::Hip>,
285 concepts::negate<concepts::all_of<
286 type_traits::is_arithmetic<RAJA::detail::IterVal<Iter>>,
287 std::is_pointer<Iter>,
289 camp::is_same<Compare,
290 operators::less<RAJA::detail::IterVal<Iter>>>,
291 camp::is_same<Compare,
292 operators::greater<RAJA::detail::IterVal<Iter>>>>>>>
294 ::RAJA::policy::hip::
295 hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
303 std::is_pointer<Iter>,
305 camp::is_same<Compare,
307 camp::is_same<Compare, operators::greater<
309 "RAJA sort<hip_exec> is only implemented for pointers to arithmetic "
310 "types and RAJA::operators::less and RAJA::operators::greater.");
312 return resources::EventProxy<resources::Hip>(hip_res);
318 template<
typename IterationMapping,
319 typename IterationGetter,
320 typename Concretizer,
323 concepts::enable_if_t<resources::EventProxy<resources::Hip>,
324 type_traits::is_arithmetic<RAJA::detail::IterVal<Iter>>,
325 std::is_pointer<Iter>>
327 ::RAJA::policy::hip::
328 hip_exec<IterationMapping, IterationGetter, Concretizer, Async> p,
333 return stable(hip_res, p, begin, end, comp);
339 template<
typename IterationMapping,
340 typename IterationGetter,
341 typename Concretizer,
344 concepts::enable_if_t<resources::EventProxy<resources::Hip>,
345 type_traits::is_arithmetic<RAJA::detail::IterVal<Iter>>,
346 std::is_pointer<Iter>>
348 ::RAJA::policy::hip::
349 hip_exec<IterationMapping, IterationGetter, Concretizer, Async> p,
354 return stable(hip_res, p, begin, end, comp);
360 template<
typename IterationMapping,
361 typename IterationGetter,
362 typename Concretizer,
367 concepts::enable_if_t<
368 resources::EventProxy<resources::Hip>,
369 concepts::negate<concepts::all_of<
370 type_traits::is_arithmetic<RAJA::detail::IterVal<KeyIter>>,
371 std::is_pointer<KeyIter>,
372 std::is_pointer<ValIter>,
374 camp::is_same<Compare,
375 operators::less<RAJA::detail::IterVal<KeyIter>>>,
378 operators::greater<RAJA::detail::IterVal<KeyIter>>>>>>>
380 resources::Hip hip_res,
381 ::RAJA::policy::hip::
382 hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
388 static_assert(std::is_pointer<KeyIter>::value,
389 "stable_sort_pairs<hip_exec> is only implemented for pointers");
390 static_assert(std::is_pointer<ValIter>::value,
391 "stable_sort_pairs<hip_exec> is only implemented for pointers");
394 type_traits::is_arithmetic<K>::value,
395 "stable_sort_pairs<hip_exec> is only implemented for arithmetic types");
397 concepts::any_of<camp::is_same<Compare, operators::less<K>>,
398 camp::is_same<Compare, operators::greater<K>>>::value,
399 "stable_sort_pairs<hip_exec> is only implemented for "
400 "RAJA::operators::less or RAJA::operators::greater");
402 return resources::EventProxy<resources::Hip>(hip_res);
408 template<
typename IterationMapping,
409 typename IterationGetter,
410 typename Concretizer,
414 concepts::enable_if_t<
415 resources::EventProxy<resources::Hip>,
416 type_traits::is_arithmetic<RAJA::detail::IterVal<KeyIter>>,
417 std::is_pointer<KeyIter>,
418 std::is_pointer<ValIter>>
420 resources::Hip hip_res,
421 ::RAJA::policy::hip::
422 hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
428 hipStream_t stream = hip_res.get_stream();
433 int len = std::distance(keys_begin, keys_end);
435 int end_bit =
sizeof(K) * CHAR_BIT;
438 K* d_keys_out = hip::device_mempool_type::getInstance().malloc<K>(len);
439 V* d_vals_out = hip::device_mempool_type::getInstance().malloc<V>(len);
443 detail::double_buffer<K> d_keys(keys_begin, d_keys_out);
444 detail::double_buffer<V> d_vals(vals_begin, d_vals_out);
447 void* d_temp_storage =
nullptr;
448 size_t temp_storage_bytes = 0;
449 #if defined(__HIPCC__)
450 CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::radix_sort_pairs, d_temp_storage,
451 temp_storage_bytes, d_keys, d_vals, len,
452 begin_bit, end_bit, stream);
453 #elif defined(__CUDACC__)
454 CAMP_CUDA_API_INVOKE_AND_CHECK(::cub::DeviceRadixSort::SortPairs,
455 d_temp_storage, temp_storage_bytes, d_keys,
456 d_vals, len, begin_bit, end_bit, stream);
460 hip::device_mempool_type::getInstance().malloc<
unsigned char>(
464 #if defined(__HIPCC__)
465 CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::radix_sort_pairs, d_temp_storage,
466 temp_storage_bytes, d_keys, d_vals, len,
467 begin_bit, end_bit, stream);
468 #elif defined(__CUDACC__)
469 CAMP_CUDA_API_INVOKE_AND_CHECK(::cub::DeviceRadixSort::SortPairs,
470 d_temp_storage, temp_storage_bytes, d_keys,
471 d_vals, len, begin_bit, end_bit, stream);
474 hip::device_mempool_type::getInstance().free(d_temp_storage);
476 if (detail::get_current(d_keys) == d_keys_out)
480 CAMP_HIP_API_INVOKE_AND_CHECK(hipMemcpyAsync, keys_begin, d_keys_out,
481 len *
sizeof(K), hipMemcpyDefault, stream);
483 if (detail::get_current(d_vals) == d_vals_out)
487 CAMP_HIP_API_INVOKE_AND_CHECK(hipMemcpyAsync, vals_begin, d_vals_out,
488 len *
sizeof(V), hipMemcpyDefault, stream);
491 hip::device_mempool_type::getInstance().free(d_keys_out);
492 hip::device_mempool_type::getInstance().free(d_vals_out);
496 return resources::EventProxy<resources::Hip>(hip_res);
502 template<
typename IterationMapping,
503 typename IterationGetter,
504 typename Concretizer,
508 concepts::enable_if_t<
509 resources::EventProxy<resources::Hip>,
510 type_traits::is_arithmetic<RAJA::detail::IterVal<KeyIter>>,
511 std::is_pointer<KeyIter>,
512 std::is_pointer<ValIter>>
514 resources::Hip hip_res,
515 ::RAJA::policy::hip::
516 hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
522 hipStream_t stream = hip_res.get_stream();
527 int len = std::distance(keys_begin, keys_end);
529 int end_bit =
sizeof(K) * CHAR_BIT;
532 K* d_keys_out = hip::device_mempool_type::getInstance().malloc<K>(len);
533 V* d_vals_out = hip::device_mempool_type::getInstance().malloc<V>(len);
537 detail::double_buffer<K> d_keys(keys_begin, d_keys_out);
538 detail::double_buffer<V> d_vals(vals_begin, d_vals_out);
541 void* d_temp_storage =
nullptr;
542 size_t temp_storage_bytes = 0;
543 #if defined(__HIPCC__)
544 CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::radix_sort_pairs_desc,
545 d_temp_storage, temp_storage_bytes, d_keys,
546 d_vals, len, begin_bit, end_bit, stream);
547 #elif defined(__CUDACC__)
548 CAMP_CUDA_API_INVOKE_AND_CHECK(::cub::DeviceRadixSort::SortPairsDescending,
549 d_temp_storage, temp_storage_bytes, d_keys,
550 d_vals, len, begin_bit, end_bit, stream);
554 hip::device_mempool_type::getInstance().malloc<
unsigned char>(
558 #if defined(__HIPCC__)
559 CAMP_HIP_API_INVOKE_AND_CHECK(::rocprim::radix_sort_pairs_desc,
560 d_temp_storage, temp_storage_bytes, d_keys,
561 d_vals, len, begin_bit, end_bit, stream);
562 #elif defined(__CUDACC__)
563 CAMP_CUDA_API_INVOKE_AND_CHECK(::cub::DeviceRadixSort::SortPairsDescending,
564 d_temp_storage, temp_storage_bytes, d_keys,
565 d_vals, len, begin_bit, end_bit, stream);
568 hip::device_mempool_type::getInstance().free(d_temp_storage);
570 if (detail::get_current(d_keys) == d_keys_out)
574 CAMP_HIP_API_INVOKE_AND_CHECK(hipMemcpyAsync, keys_begin, d_keys_out,
575 len *
sizeof(K), hipMemcpyDefault, stream);
577 if (detail::get_current(d_vals) == d_vals_out)
581 CAMP_HIP_API_INVOKE_AND_CHECK(hipMemcpyAsync, vals_begin, d_vals_out,
582 len *
sizeof(V), hipMemcpyDefault, stream);
585 hip::device_mempool_type::getInstance().free(d_keys_out);
586 hip::device_mempool_type::getInstance().free(d_vals_out);
590 return resources::EventProxy<resources::Hip>(hip_res);
596 template<
typename IterationMapping,
597 typename IterationGetter,
598 typename Concretizer,
603 concepts::enable_if_t<
604 resources::EventProxy<resources::Hip>,
605 concepts::negate<concepts::all_of<
606 type_traits::is_arithmetic<RAJA::detail::IterVal<KeyIter>>,
607 std::is_pointer<KeyIter>,
608 std::is_pointer<ValIter>,
610 camp::is_same<Compare,
611 operators::less<RAJA::detail::IterVal<KeyIter>>>,
614 operators::greater<RAJA::detail::IterVal<KeyIter>>>>>>>
616 resources::Hip hip_res,
617 ::RAJA::policy::hip::
618 hip_exec<IterationMapping, IterationGetter, Concretizer, Async>,
624 static_assert(std::is_pointer<KeyIter>::value,
625 "sort_pairs<hip_exec> is only implemented for pointers");
626 static_assert(std::is_pointer<ValIter>::value,
627 "sort_pairs<hip_exec> is only implemented for pointers");
630 type_traits::is_arithmetic<K>::value,
631 "sort_pairs<hip_exec> is only implemented for arithmetic types");
633 concepts::any_of<camp::is_same<Compare, operators::less<K>>,
634 camp::is_same<Compare, operators::greater<K>>>::value,
635 "sort_pairs<hip_exec> is only implemented for RAJA::operators::less or "
636 "RAJA::operators::greater");
638 return resources::EventProxy<resources::Hip>(hip_res);
644 template<
typename IterationMapping,
645 typename IterationGetter,
646 typename Concretizer,
650 concepts::enable_if_t<
651 resources::EventProxy<resources::Hip>,
652 type_traits::is_arithmetic<RAJA::detail::IterVal<KeyIter>>,
653 std::is_pointer<KeyIter>,
654 std::is_pointer<ValIter>>
656 resources::Hip hip_res,
657 ::RAJA::policy::hip::
658 hip_exec<IterationMapping, IterationGetter, Concretizer, Async> p,
664 return stable_pairs(hip_res, p, keys_begin, keys_end, vals_begin, comp);
670 template<
typename IterationMapping,
671 typename IterationGetter,
672 typename Concretizer,
676 concepts::enable_if_t<
677 resources::EventProxy<resources::Hip>,
678 type_traits::is_arithmetic<RAJA::detail::IterVal<KeyIter>>,
679 std::is_pointer<KeyIter>,
680 std::is_pointer<ValIter>>
682 resources::Hip hip_res,
683 ::RAJA::policy::hip::
684 hip_exec<IterationMapping, IterationGetter, Concretizer, Async> p,
690 return stable_pairs(hip_res, p, keys_begin, keys_end, vals_begin, comp);
Header file defining prototypes for routines used to manage memory for HIP reductions and other operations.
Header file for RAJA operator definitions.
Header file for RAJA algorithm definitions.
Header file for RAJA concept definitions.
Header file containing RAJA HIP policy definitions.
typename ::std::iterator_traits< Iter >::value_type IterVal
Definition: algorithm.hpp:38
concepts::enable_if_t< resources::EventProxy< resources::Host >, type_traits::is_openmp_policy< ExecPolicy > > stable_pairs(resources::Host host_res, const ExecPolicy &, KeyIter keys_begin, KeyIter keys_end, ValIter vals_begin, Compare comp)
Definition: sort.hpp:276
concepts::enable_if_t< resources::EventProxy< resources::Host >, type_traits::is_openmp_policy< ExecPolicy > > stable(resources::Host host_res, const ExecPolicy &, Iter begin, Iter end, Compare comp)
Stable-sort the given range using a comparison function.
Definition: sort.hpp:230
concepts::enable_if_t< resources::EventProxy< resources::Host >, type_traits::is_openmp_policy< ExecPolicy > > unstable(resources::Host host_res, const ExecPolicy &, Iter begin, Iter end, Compare comp)
Sort the given range using a comparison function.
Definition: sort.hpp:213
concepts::enable_if_t< resources::EventProxy< resources::Host >, type_traits::is_openmp_policy< ExecPolicy > > unstable_pairs(resources::Host host_res, const ExecPolicy &, KeyIter keys_begin, KeyIter keys_end, ValIter vals_begin, Compare comp)
Sort the given range of key/value pairs using a comparison function on the keys.
Definition: sort.hpp:250
concepts::enable_if_t< resources::EventProxy< Res >, type_traits::is_execution_policy< ExecPolicy >, type_traits::is_resource< Res >, std::is_constructible< camp::resources::Resource, Res >, type_traits::is_range< Container > > sort(ExecPolicy &&p, Res r, Container &&c, Compare comp=Compare {})
Sort execution pattern.
Definition: sort.hpp:61
Definition: AlignedRangeIndexSetBuilders.cpp:35
void launch(LaunchParams const &launch_params, ReduceParams &&... rest_of_launch_args)
Definition: launch_core.hpp:268