RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
Dispatcher.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_cuda_WorkGroup_Dispatcher_HPP
21 #define RAJA_cuda_WorkGroup_Dispatcher_HPP
22 
23 #include "RAJA/config.hpp"
24 
25 #include "camp/resource.hpp"
26 
28 
30 
31 #include <thread>
32 #include <mutex>
33 
34 namespace RAJA
35 {
36 
37 namespace detail
38 {
39 
40 namespace cuda
41 {
42 
// Kernel: invoke the factory in device code and store what it returns
// through ptr.
//
//   ptr     - destination for the produced value; get_value() passes the
//             buffer obtained from get_cached_value_ptr() so the host can
//             read the result after synchronizing the stream.
//   factory - callable taken by value; Factory::value_type names the type
//             of the value it produces.
//
// There is no thread-index guard: every launched thread performs the same
// store. get_value() launches this kernel with a grid of 1 and a block of 1.
template<typename Factory>
__global__ void get_value_global(typename Factory::value_type* ptr,
                                 Factory factory)
{
  ptr[0] = factory();
}
51 
// Return the process-wide host buffer allocated with cudaMallocHost,
// growing it when the request is larger than the current allocation.
//
//   nbytes  - minimum number of bytes the caller needs.
//   returns - pointer to a buffer of at least nbytes bytes, or nullptr if
//             nothing has been allocated yet and nbytes == 0.
//
// The buffer is never shrunk and is kept in function-local statics. This
// function performs no locking of its own; callers are expected to hold
// get_value_mutex() while calling it and while using the returned buffer.
//
// Growing the buffer releases the old allocation first. The static pointer
// is cleared *before* the free and only updated again after a successful
// allocation, so if a checked CUDA call reports its failure without
// terminating the process (e.g. by throwing), a later call can neither hand
// out nor free a stale pointer a second time.
inline void* get_cached_value_ptr(size_t nbytes)
{
  static size_t cached_nbytes = 0;
  static void* ptr            = nullptr;

  if (nbytes > cached_nbytes)
  {
    // Invalidate the cached state up front so a failure below leaves the
    // cache empty rather than pointing at released memory.
    void* old_ptr = ptr;
    ptr           = nullptr;
    cached_nbytes = 0;

    // Nothing to release on the very first allocation.
    if (old_ptr != nullptr)
    {
      CAMP_CUDA_API_INVOKE_AND_CHECK(cudaFreeHost, old_ptr);
    }

    // Allocate into a temporary and publish it only once the call succeeded.
    void* new_ptr = nullptr;
    CAMP_CUDA_API_INVOKE_AND_CHECK(cudaMallocHost, &new_ptr, nbytes);

    ptr           = new_ptr;
    cached_nbytes = nbytes;
  }

  return ptr;
}
66 
// Accessor for the single mutex that serializes use of the pinned buffer
// and calls to get_cached_value_ptr().
//
// The mutex is a function-local static, so every call returns a reference
// to the same object.
inline std::mutex& get_value_mutex()
{
  static std::mutex pinned_buffer_mutex;
  return pinned_buffer_mutex;
}
74 
// Produce a value in device code and return a host copy of it.
//
// The factory is handed to the get_value_global kernel, which stores
// factory() into the buffer from get_cached_value_ptr(); the kernel is
// launched on the stream of the default camp Cuda resource with a grid of 1
// and a block of 1, and that stream is synchronized before the buffer is
// read on the host.
//
//   factory - callable whose decayed type provides value_type.
//   returns - the value the kernel wrote, by value.
//
// Different instantiations of this function may be called concurrently, so
// the whole sequence (buffer lookup, launch, synchronize, read) runs while
// holding get_value_mutex().
template<typename Factory>
inline auto get_value(Factory&& factory)
{
  using factory_type = std::decay_t<Factory>;
  using value_type   = typename factory_type::value_type;

  // Held until return: the pinned buffer is shared by every instantiation.
  const std::lock_guard<std::mutex> lock(get_value_mutex());

  auto resource = ::camp::resources::Cuda::get_default();
  auto pinned_value =
      static_cast<value_type*>(get_cached_value_ptr(sizeof(value_type)));

  // cudaLaunchKernel takes the kernel as an untyped pointer plus an array
  // holding the address of each kernel argument.
  const void* kernel =
      reinterpret_cast<const void*>(&get_value_global<factory_type>);
  void* kernel_args[] = {(void*)&pinned_value, (void*)&factory};

  CAMP_CUDA_API_INVOKE_AND_CHECK(cudaLaunchKernel, kernel, 1, 1, kernel_args,
                                 0, resource.get_stream());
  // Wait for the kernel so the store is complete before the host reads it.
  CAMP_CUDA_API_INVOKE_AND_CHECK(cudaStreamSynchronize,
                                 resource.get_stream());

  return *pinned_value;
}
95 
// Memoizing wrapper around get_value().
//
// The result of get_value() is stored in a function-local static, so the
// device round trip happens only on the first call of each instantiation
// of this template. Later calls return a copy of the stored value and do
// not use their factory argument.
//
//   factory - forwarded to get_value() on the first call only.
//   returns - copy of the cached value.
template<typename Factory>
inline auto get_cached_value(Factory&& factory)
{
  static auto s_cached_value = get_value(std::forward<Factory>(factory));
  return s_cached_value;
}
104 
105 } // namespace cuda
106 
// Return the Dispatcher_T instance for type T under the
// cuda_work_explicit<BLOCK_SIZE, BLOCKS_PER_SM, Async> policy.
//
// The dispatcher is a function-local static, built once per instantiation
// by Dispatcher_T::makeDispatcher<T>. makeDispatcher receives a callable
// that forwards whatever factory it is given to cuda::get_cached_value().
//
// The policy argument is unnamed: only its type is used, to select this
// overload.
//
//   returns - pointer to the static dispatcher; the caller does not own it.
template<typename T,
         typename Dispatcher_T,
         size_t BLOCK_SIZE,
         size_t BLOCKS_PER_SM,
         bool Async>
inline const Dispatcher_T* get_Dispatcher(
    cuda_work_explicit<BLOCK_SIZE, BLOCKS_PER_SM, Async> const&)
{
  static Dispatcher_T s_dispatcher {
      Dispatcher_T::template makeDispatcher<T>(
          [](auto&& value_factory) {
            return cuda::get_cached_value(
                std::forward<decltype(value_factory)>(value_factory));
          })};

  return &s_dispatcher;
}
124 
125 } // namespace detail
126 
127 } // namespace RAJA
128 
129 #endif // closing endif for header file include guard
Header file containing RAJA CUDA policy definitions.
auto get_value(Factory &&factory)
Definition: Dispatcher.hpp:79
std::mutex & get_value_mutex()
Definition: Dispatcher.hpp:69
void * get_cached_value_ptr(size_t nbytes)
Definition: Dispatcher.hpp:53
auto get_cached_value(Factory &&factory)
Definition: Dispatcher.hpp:99
__global__ void get_value_global(typename Factory::value_type *ptr, Factory factory)
Definition: Dispatcher.hpp:46
Args args
Definition: WorkRunner.hpp:212
const Dispatcher_T * get_Dispatcher(cuda_work_explicit< BLOCK_SIZE, BLOCKS_PER_SM, Async > const &)
Definition: Dispatcher.hpp:115
Definition: AlignedRangeIndexSetBuilders.cpp:35
Header file providing RAJA Dispatcher for workgroup.