RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
InitLocalMem.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 
21 #ifndef RAJA_policy_hip_kernel_InitLocalMem_HPP
22 #define RAJA_policy_hip_kernel_InitLocalMem_HPP
23 
24 #include "RAJA/config.hpp"
25 
26 #include <iostream>
27 #include <type_traits>
28 
29 #include "RAJA/util/macros.hpp"
30 #include "RAJA/util/types.hpp"
31 
34 
35 namespace RAJA
36 {
37 
38 struct hip_thread_mem;
39 struct hip_shared_mem;
40 
41 namespace internal
42 {
43 
44 // Intialize thread shared array
45 template<typename Data,
46  camp::idx_t... Indices,
47  typename... EnclosedStmts,
48  typename Types>
49 struct HipStatementExecutor<Data,
50  statement::InitLocalMem<RAJA::hip_shared_mem,
51  camp::idx_seq<Indices...>,
52  EnclosedStmts...>,
53  Types>
54 {
55 
56  using stmt_list_t = StatementList<EnclosedStmts...>;
57  using enclosed_stmts_t = HipStatementListExecutor<Data, stmt_list_t, Types>;
58 
59  // Launch loops
60  template<camp::idx_t Pos>
61  static inline RAJA_DEVICE void initMem(Data& data, bool thread_active)
62  {
63  using varType = typename camp::tuple_element_t<
64  Pos, typename camp::decay<Data>::param_tuple_t>::value_type;
65  const camp::idx_t NumElem = camp::tuple_element_t<
66  Pos, typename camp::decay<Data>::param_tuple_t>::layout_type::s_size;
67 
68  __shared__ varType Array[NumElem];
69  camp::get<Pos>(data.param_tuple).set_data(&Array[0]);
70 
71  enclosed_stmts_t::exec(data, thread_active);
72  }
73 
74  // Intialize local array
75  // Identifies type + number of elements needed
76  template<camp::idx_t Pos, camp::idx_t other0, camp::idx_t... others>
77  static inline RAJA_DEVICE void initMem(Data& data, bool thread_active)
78  {
79  using varType = typename camp::tuple_element_t<
80  Pos, typename camp::decay<Data>::param_tuple_t>::value_type;
81  const camp::idx_t NumElem = camp::tuple_element_t<
82  Pos, typename camp::decay<Data>::param_tuple_t>::layout_type::s_size;
83 
84  __shared__ varType Array[NumElem];
85  camp::get<Pos>(data.param_tuple).set_data(&Array[0]);
86  initMem<other0, others...>(data, thread_active);
87  }
88 
89  // Set pointer to null base case
90  template<camp::idx_t Pos>
91  static inline RAJA_DEVICE void setPtrToNull(Data& data)
92  {
93 
94  camp::get<Pos>(data.param_tuple).set_data(nullptr);
95  }
96 
97  // Set pointer to null recursive case
98  template<camp::idx_t Pos, camp::idx_t other0, camp::idx_t... others>
99  static inline RAJA_DEVICE void setPtrToNull(Data& data)
100  {
101 
102  camp::get<Pos>(data.param_tuple).set_data(nullptr);
103  setPtrToNull<other0, others...>(data);
104  }
105 
106  static inline RAJA_DEVICE void exec(Data& data, bool thread_active)
107  {
108 
109  // Intialize scoped arrays + launch loops
110  initMem<Indices...>(data, thread_active);
111 
112  // set pointers in scoped arrays to null
113  setPtrToNull<Indices...>(data);
114  }
115 
116  inline static LaunchDims calculateDimensions(Data const& data)
117  {
118  return enclosed_stmts_t::calculateDimensions(data);
119  }
120 };
121 
122 // Intialize thread private array
123 template<typename Data,
124  camp::idx_t... Indices,
125  typename... EnclosedStmts,
126  typename Types>
127 struct HipStatementExecutor<Data,
128  statement::InitLocalMem<RAJA::hip_thread_mem,
129  camp::idx_seq<Indices...>,
130  EnclosedStmts...>,
131  Types>
132 {
133 
134  using stmt_list_t = StatementList<EnclosedStmts...>;
135  using enclosed_stmts_t = HipStatementListExecutor<Data, stmt_list_t, Types>;
136 
137  // Launch loops
138  template<camp::idx_t Pos>
139  static inline RAJA_DEVICE void initMem(Data& data, bool thread_active)
140  {
141  using varType = typename camp::tuple_element_t<
142  Pos, typename camp::decay<Data>::param_tuple_t>::value_type;
143  const camp::idx_t NumElem = camp::tuple_element_t<
144  Pos, typename camp::decay<Data>::param_tuple_t>::layout_type::s_size;
145 
146  varType Array[NumElem];
147  camp::get<Pos>(data.param_tuple).set_data(&Array[0]);
148 
149  enclosed_stmts_t::exec(data, thread_active);
150  }
151 
152  // Intialize local array
153  // Identifies type + number of elements needed
154  template<camp::idx_t Pos, camp::idx_t other0, camp::idx_t... others>
155  static inline RAJA_DEVICE void initMem(Data& data, bool thread_active)
156  {
157  using varType = typename camp::tuple_element_t<
158  Pos, typename camp::decay<Data>::param_tuple_t>::value_type;
159  const camp::idx_t NumElem = camp::tuple_element_t<
160  Pos, typename camp::decay<Data>::param_tuple_t>::layout_type::s_size;
161 
162  varType Array[NumElem];
163  camp::get<Pos>(data.param_tuple).set_data(&Array[0]);
164  initMem<other0, others...>(data, thread_active);
165  }
166 
167  // Set pointer to null base case
168  template<camp::idx_t Pos>
169  static inline RAJA_DEVICE void setPtrToNull(Data& data)
170  {
171 
172  camp::get<Pos>(data.param_tuple).set_data(nullptr);
173  }
174 
175  // Set pointer to null recursive case
176  template<camp::idx_t Pos, camp::idx_t other0, camp::idx_t... others>
177  static inline RAJA_DEVICE void setPtrToNull(Data& data)
178  {
179 
180  camp::get<Pos>(data.param_tuple).set_data(nullptr);
181  setPtrToNull<other0, others...>(data);
182  }
183 
184  static inline RAJA_DEVICE void exec(Data& data, bool thread_active)
185  {
186 
187  // Intialize scoped arrays + launch loops
188  initMem<Indices...>(data, thread_active);
189 
190  // set pointers in scoped arrays to null
191  setPtrToNull<Indices...>(data);
192  }
193 
194  inline static LaunchDims calculateDimensions(Data const& data)
195  {
196  return enclosed_stmts_t::calculateDimensions(data);
197  }
198 };
199 
200 
201 } // namespace internal
202 } // end namespace RAJA
203 
204 
205 #endif
Header file for common RAJA internal macro definitions.
#define RAJA_DEVICE
Definition: macros.hpp:66
camp::list< Stmts... > StatementList
Definition: StatementList.hpp:41
Definition: AlignedRangeIndexSetBuilders.cpp:35
Header file for shared memory window.
RAJA header file containing constructs used to run kernel traversals on GPU with HIP.
Header file for RAJA type definitions.