20 #ifndef RAJA_PATTERN_WORKGROUP_Dispatcher_HPP
21 #define RAJA_PATTERN_WORKGROUP_Dispatcher_HPP
24 #include "RAJA/config.hpp"
28 #include "camp/number.hpp"
29 #include "camp/list.hpp"
30 #include "camp/helpers.hpp"
62 return !(platform == Platform::cuda || platform == Platform::hip);
67 template<
typename dispatch_policy,
typename holder_type>
70 template<
typename dispatch_policy,
typename holder_type>
81 template<Platform platform,
82 typename dispatch_policy,
83 typename DispatcherID,
87 template<
typename holder_type>
103 template<Platform platform,
typename DispatcherID,
typename... CallArgs>
122 T* dest_as_T =
static_cast<T*
>(dest.
ptr);
123 T* src_as_T =
static_cast<T*
>(src.
ptr);
124 new (dest_as_T) T(std::move(*src_as_T));
134 const T* obj_as_T =
static_cast<const T*
>(obj.
ptr);
135 (*obj_as_T)(std::forward<CallArgs>(
args)...);
143 const T* obj_as_T =
static_cast<const T*
>(obj.
ptr);
144 (*obj_as_T)(std::forward<CallArgs>(
args)...);
153 T* obj_as_T =
static_cast<T*
>(obj.
ptr);
165 struct DeviceInvokerFactory
171 #if defined(RAJA_ENABLE_HIP) && !defined(RAJA_ENABLE_HIP_INDIRECT_FUNCTION_CALL)
174 return &s_device_invoke<T>;
183 bool uhi = use_host_invoke,
184 std::enable_if_t<uhi>* =
nullptr>
187 return {
mover_type {&s_move_construct_destroy<T>},
203 typename CreateOnDevice,
204 bool uhi = use_host_invoke,
205 std::enable_if_t<!uhi>* =
nullptr>
208 return {
mover_type {&s_move_construct_destroy<T>},
209 invoker_type {std::forward<CreateOnDevice>(createOnDevice)(
210 DeviceInvokerFactory<T> {})},
220 template<
typename holder_type>
236 template<Platform platform,
typename DispatcherID,
typename... CallArgs>
254 struct host_impl_base
259 struct device_impl_base
262 CallArgs...
args)
const = 0;
266 struct base_impl_type : impl_base
275 T* dest_as_T =
static_cast<T*
>(dest.
ptr);
276 T* src_as_T =
static_cast<T*
>(src.
ptr);
277 new (dest_as_T) T(std::move(*src_as_T));
286 T* obj_as_T =
static_cast<T*
>(obj.
ptr);
292 struct host_impl_type : host_impl_base
299 const T* obj_as_T =
static_cast<const T*
>(obj.
ptr);
300 (*obj_as_T)(std::forward<CallArgs>(
args)...);
305 struct device_impl_type : device_impl_base
311 CallArgs...
args)
const override
313 const T* obj_as_T =
static_cast<const T*
>(obj.
ptr);
314 (*obj_as_T)(std::forward<CallArgs>(
args)...);
324 m_impl->move_destroy(dest, src);
328 struct host_invoker_type
334 m_impl->invoke(obj, std::forward<CallArgs>(
args)...);
339 struct device_invoker_type
345 m_impl->invoke(obj, std::forward<CallArgs>(
args)...);
350 conditional_t<use_host_invoke, host_invoker_type, device_invoker_type>;
352 struct destroyer_type
361 struct DeviceImplTypeFactory
367 #if defined(RAJA_ENABLE_HIP) && !defined(RAJA_ENABLE_HIP_INDIRECT_FUNCTION_CALL)
370 static device_impl_type<T> s_device_impl;
371 return &s_device_impl;
380 bool uhi = use_host_invoke,
381 std::enable_if_t<uhi>* =
nullptr>
384 static base_impl_type<T> s_base_impl;
385 static host_impl_type<T> s_host_impl;
386 return {mover_type {&s_base_impl}, host_invoker_type {&s_host_impl},
387 destroyer_type {&s_base_impl},
sizeof(T)};
401 typename CreateOnDevice,
402 bool uhi = use_host_invoke,
403 std::enable_if_t<!uhi>* =
nullptr>
406 static base_impl_type<T> s_base_impl;
407 static device_impl_type<T>* s_device_impl_ptr {std::forward<CreateOnDevice>(
408 createOnDevice)(DeviceImplTypeFactory<T> {})};
409 return {mover_type {&s_base_impl}, device_invoker_type {s_device_impl_ptr},
410 destroyer_type {&s_base_impl},
sizeof(T)};
420 template<
typename... Ts,
typename holder_type>
431 template<Platform platform,
typename DispatcherID,
typename... CallArgs>
454 struct host_invoker_type
459 struct device_invoker_type
465 conditional_t<use_host_invoke, host_invoker_type, device_invoker_type>;
470 struct destroyer_type
479 bool uhi = use_host_invoke,
480 std::enable_if_t<uhi>* =
nullptr>
483 return {mover_type {}, host_invoker_type {}, destroyer_type {},
sizeof(T)};
493 typename CreateOnDevice,
494 bool uhi = use_host_invoke,
495 std::enable_if_t<!uhi>* =
nullptr>
498 return {mover_type {}, device_invoker_type {}, destroyer_type {},
512 template<Platform platform,
514 typename DispatcherID,
515 typename... CallArgs>
534 T* dest_as_T =
static_cast<T*
>(dest.
ptr);
535 T* src_as_T =
static_cast<T*
>(src.
ptr);
536 new (dest_as_T) T(std::move(*src_as_T));
544 struct host_invoker_type
548 const T* obj_as_T =
static_cast<const T*
>(obj.
ptr);
549 (*obj_as_T)(std::forward<CallArgs>(
args)...);
553 struct device_invoker_type
557 const T* obj_as_T =
static_cast<const T*
>(obj.
ptr);
558 (*obj_as_T)(std::forward<CallArgs>(
args)...);
563 conditional_t<use_host_invoke, host_invoker_type, device_invoker_type>;
568 struct destroyer_type
572 T* obj_as_T =
static_cast<T*
>(obj.
ptr);
581 bool uhi = use_host_invoke,
582 std::enable_if_t<uhi>* =
nullptr>
585 static_assert(std::is_same<T, U>::value,
586 "U must be in direct_dispatch types");
587 return {mover_type {}, host_invoker_type {}, destroyer_type {},
sizeof(T)};
597 typename CreateOnDevice,
598 bool uhi = use_host_invoke,
599 std::enable_if_t<!uhi>* =
nullptr>
602 static_assert(std::is_same<T, U>::value,
603 "U must be in direct_dispatch types");
604 return {mover_type {}, device_invoker_type {}, destroyer_type {},
618 template<
typename T0,
622 typename DispatcherID,
623 typename... CallArgs>
652 template<
int... id_types,
typename... Ts>
653 void impl_helper(camp::int_seq<int, id_types...>,
658 camp::sink(((id_types ==
id) ? (impl<Ts>(dest, src), 0) : 0)...);
662 void impl(void_ptr_wrapper dest, void_ptr_wrapper src)
const
664 T* dest_as_T =
static_cast<T*
>(dest.ptr);
665 T* src_as_T =
static_cast<T*
>(src.ptr);
666 new (dest_as_T) T(std::move(*src_as_T));
674 struct host_invoker_type
681 std::forward<CallArgs>(
args)...);
685 template<
int... id_types,
typename... Ts>
686 void impl_helper(camp::int_seq<int, id_types...>,
689 CallArgs...
args)
const
691 camp::sink(((id_types ==
id)
692 ? (impl<Ts>(obj, std::forward<CallArgs>(
args)...), 0)
697 void impl(void_cptr_wrapper obj, CallArgs...
args)
const
699 const T* obj_as_T =
static_cast<const T*
>(obj.ptr);
700 (*obj_as_T)(std::forward<CallArgs>(
args)...);
704 struct device_invoker_type
711 std::forward<CallArgs>(
args)...);
715 template<
int... id_types,
typename... Ts>
716 RAJA_DEVICE void impl_helper(camp::int_seq<int, id_types...>,
719 CallArgs...
args)
const
721 camp::sink(((id_types ==
id)
722 ? (impl<Ts>(obj, std::forward<CallArgs>(
args)...), 0)
729 const T* obj_as_T =
static_cast<const T*
>(obj.ptr);
730 (*obj_as_T)(std::forward<CallArgs>(
args)...);
735 conditional_t<use_host_invoke, host_invoker_type, device_invoker_type>;
740 struct destroyer_type
750 template<
int... id_types,
typename... Ts>
751 void impl_helper(camp::int_seq<int, id_types...>,
755 camp::sink(((id_types ==
id) ? (impl<Ts>(obj), 0) : 0)...);
759 void impl(void_ptr_wrapper obj)
const
761 T* obj_as_T =
static_cast<T*
>(obj.ptr);
772 template<
typename T,
int... id_types,
typename... Ts>
779 (std::is_same<T, Ts>::value ? ((
id = id_types), 0) : 0)...};
788 bool uhi = use_host_invoke,
789 std::enable_if_t<uhi>* =
nullptr>
794 static_assert(
id !=
id_type(-1),
"T must be in direct_dispatch types");
795 return {mover_type {
id}, host_invoker_type {
id}, destroyer_type {
id},
806 typename CreateOnDevice,
807 bool uhi = use_host_invoke,
808 std::enable_if_t<!uhi>* =
nullptr>
813 static_assert(
id !=
id_type(-1),
"T must be in direct_dispatch types");
814 return {mover_type {
id}, device_invoker_type {
id}, destroyer_type {
id},
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
#define RAJA_DEVICE
Definition: macros.hpp:66
Args args
Definition: WorkRunner.hpp:212
constexpr bool dispatcher_use_host_invoke(Platform platform)
Definition: Dispatcher.hpp:60
typename dispatcher_transform_types< dispatch_policy, holder_type >::type dispatcher_transform_types_t
Definition: Dispatcher.hpp:72
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA wrapper for "multi-policy" and dynamic policy selection.
Definition: Dispatcher.hpp:52
DispatcherVoidConstPtrWrapper()=default
RAJA_HOST_DEVICE DispatcherVoidConstPtrWrapper(const void *p)
Definition: Dispatcher.hpp:57
const void * ptr
Definition: Dispatcher.hpp:53
Definition: Dispatcher.hpp:42
void * ptr
Definition: Dispatcher.hpp:43
DispatcherVoidPtrWrapper()=default
RAJA_HOST_DEVICE DispatcherVoidPtrWrapper(void *p)
Definition: Dispatcher.hpp:47
Definition: Dispatcher.hpp:85
Definition: WorkGroup.hpp:94
Dispatch using function pointers to make indirect function calls.
Definition: WorkGroup.hpp:77
Dispatch using virtual functions to make indirect function calls.
Definition: WorkGroup.hpp:83