RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
Dispatcher.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_hip_WorkGroup_Dispatcher_HPP
21 #define RAJA_hip_WorkGroup_Dispatcher_HPP
22 
23 #include "RAJA/config.hpp"
24 
25 #include "camp/resource.hpp"
26 
28 
30 
31 #include <thread>
32 #include <mutex>
33 
34 namespace RAJA
35 {
36 
37 namespace detail
38 {
39 
40 namespace hip
41 {
42 
43 // global function that creates the value on the device using the
44 // factory and writes it into a pinned ptr
45 template<typename Factory>
46 __global__ void get_value_global(typename Factory::value_type* ptr,
47  Factory factory)
48 {
49  *ptr = factory();
50 }
51 
52 // get the pinned ptr buffer
53 inline void* get_cached_value_ptr(size_t nbytes)
54 {
55  static size_t cached_nbytes = 0;
56  static void* ptr = nullptr;
57  if (nbytes > cached_nbytes)
58  {
59  cached_nbytes = 0;
60  CAMP_HIP_API_INVOKE_AND_CHECK(hipHostFree, ptr);
61  CAMP_HIP_API_INVOKE_AND_CHECK(hipHostMalloc, &ptr, nbytes);
62  cached_nbytes = nbytes;
63  }
64  return ptr;
65 }
66 
// Process-wide mutex serializing every use of the shared pinned buffer,
// including calls to get_cached_value_ptr().
// The function-local static is constructed on first use; C++ guarantees that
// this initialization is thread-safe, and every call returns the same object.
inline std::mutex& get_value_mutex()
{
  static std::mutex pinned_buffer_mutex;
  return pinned_buffer_mutex;
}
74 
75 // get the device function pointer by calling a global function to
76 // write it into a pinned ptr, beware different instantiates of this
77 // function may run concurrently
78 template<typename Factory>
79 inline auto get_value(Factory&& factory)
80 {
81  using value_type = typename std::decay_t<Factory>::value_type;
82  const std::lock_guard<std::mutex> lock(get_value_mutex());
83 
84  auto res = ::camp::resources::Hip::get_default();
85  auto ptr = static_cast<value_type*>(get_cached_value_ptr(sizeof(value_type)));
86  auto func =
87  reinterpret_cast<const void*>(&get_value_global<std::decay_t<Factory>>);
88  void* args[] = {(void*)&ptr, (void*)&factory};
89  CAMP_HIP_API_INVOKE_AND_CHECK(hipLaunchKernel, func, 1, 1, args, 0,
90  res.get_stream());
91  CAMP_HIP_API_INVOKE_AND_CHECK(hipStreamSynchronize, res.get_stream());
92 
93  return *ptr;
94 }
95 
// Return the value produced by get_value(factory), computing it only once.
//
// The result is stored in a function-local static. get_value() therefore runs
// only on the first call of each instantiation of this template (one per
// deduced Factory type), and later calls return copies of the stored value.
template<typename Factory>
inline auto get_cached_value(Factory&& factory)
{
  using cached_type = decltype(get_value(std::forward<Factory>(factory)));
  static cached_type cached_result = get_value(std::forward<Factory>(factory));
  return cached_result;
}
104 
105 } // namespace hip
106 
110 template<typename T, typename Dispatcher_T, size_t BLOCK_SIZE, bool Async>
111 inline const Dispatcher_T* get_Dispatcher(hip_work<BLOCK_SIZE, Async> const&)
112 {
113  static Dispatcher_T dispatcher {
114  Dispatcher_T::template makeDispatcher<T>([](auto&& factory) {
115  return hip::get_cached_value(std::forward<decltype(factory)>(factory));
116  })};
117  return &dispatcher;
118 }
119 
120 } // namespace detail
121 
122 } // namespace RAJA
123 
124 #endif // closing endif for header file include guard
Header file containing RAJA HIP policy definitions.
__global__ void get_value_global(typename Factory::value_type *ptr, Factory factory)
Definition: Dispatcher.hpp:46
void * get_cached_value_ptr(size_t nbytes)
Definition: Dispatcher.hpp:53
auto get_cached_value(Factory &&factory)
Definition: Dispatcher.hpp:99
auto get_value(Factory &&factory)
Definition: Dispatcher.hpp:79
std::mutex & get_value_mutex()
Definition: Dispatcher.hpp:69
Args args
Definition: WorkRunner.hpp:212
const Dispatcher_T * get_Dispatcher(cuda_work_explicit< BLOCK_SIZE, BLOCKS_PER_SM, Async > const &)
Definition: Dispatcher.hpp:115
Definition: AlignedRangeIndexSetBuilders.cpp:35
Header file providing RAJA Dispatcher for workgroup.