RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
MatrixMatrixMultiply.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_pattern_tensor_internal_MatrixMatrixMultiply_HPP
21 #define RAJA_pattern_tensor_internal_MatrixMatrixMultiply_HPP
22 
23 #include "camp/camp.hpp"
24 #include "RAJA/config.hpp"
26 
27 namespace RAJA
28 {
29 namespace internal
30 {
31 namespace expt
32 {
33 
34 
// Template header of the primary (unspecialized) declaration, parameterized on
// two operand types MATA and MATB.
// NOTE(review): the line that should follow this header (listing line 36) is
// absent from this listing, so the declared name is not visible here. The
// cross-reference list at the end of the page points at a definition on
// MatrixMatrixMultiply.hpp:36 -- confirm the declaration against the original
// header.
35 template<typename MATA, typename MATB>
37 
// Partial specialization for two RAJA::expt::TensorRegister operands that both
// use RowMajorLayout and share the element type T and REGISTER_POLICY:
//   left  operand: sizes <N_SIZE,  M_SIZE>
//   right operand: sizes <M2_SIZE, O_SIZE>
// It provides multiply_accumulate() (two enable_if-selected overloads) and
// multiply().
//
// NOTE(review): this listing is missing several lines of the original header
// (the embedded line numbers skip 49, 65, 70, 75, 96, 131 and 171). Per the
// cross-references at the end of the page, lines 96 and 131 carry
// "multiply_accumulate(left_type const &A,". The others (presumably the struct
// head and the layout argument of each alias) must be confirmed against the
// original file.
43 template<typename T,
44  typename REGISTER_POLICY,
45  camp::idx_t N_SIZE,
46  camp::idx_t M_SIZE,
47  camp::idx_t M2_SIZE,
48  camp::idx_t O_SIZE>
50  RAJA::expt::TensorRegister<REGISTER_POLICY,
51  T,
52  RAJA::expt::RowMajorLayout,
53  camp::idx_seq<N_SIZE, M_SIZE>>,
54  RAJA::expt::TensorRegister<REGISTER_POLICY,
55  T,
56  RAJA::expt::RowMajorLayout,
57  camp::idx_seq<M2_SIZE, O_SIZE>>>
58 {
59 
    // The second size of the left operand must equal the first size of the
    // right operand; otherwise compilation fails with the message below.
60  static_assert(M_SIZE == M2_SIZE,
61  "Matrices are not compatible for multiplication");
62 
    // Left operand type, sizes <N_SIZE, M_SIZE>.
    // NOTE(review): the layout argument (listing line 65) is absent here.
63  using left_type = RAJA::expt::TensorRegister<REGISTER_POLICY,
64  T,
66  camp::idx_seq<N_SIZE, M_SIZE>>;
67 
    // Right operand type, sizes <M_SIZE, O_SIZE>.
    // NOTE(review): the layout argument (listing line 70) is absent here.
68  using right_type = RAJA::expt::TensorRegister<REGISTER_POLICY,
69  T,
71  camp::idx_seq<M_SIZE, O_SIZE>>;
72 
    // Result type, sizes <N_SIZE, O_SIZE>.
    // NOTE(review): the layout argument (listing line 75) is absent here.
73  using result_type = RAJA::expt::TensorRegister<REGISTER_POLICY,
74  T,
76  camp::idx_seq<N_SIZE, O_SIZE>>;
77 
    // Register type used for temporaries; taken from the result type.
78  using register_type = typename result_type::register_type;
79 
    // Constants forwarded from the operand/result types. Only
    // s_C_minor_dim_registers is read by the code visible in this struct: it
    // selects which multiply_accumulate overload is enabled.
80  static constexpr camp::idx_t s_elements_per_register =
81  left_type::s_elements_per_register;
82  static constexpr camp::idx_t s_A_minor_dim_registers =
83  left_type::s_minor_dim_registers;
84  static constexpr camp::idx_t s_B_minor_dim_registers =
85  right_type::s_minor_dim_registers;
86  static constexpr camp::idx_t s_C_minor_dim_registers =
87  result_type::s_minor_dim_registers;
88 
89  /*
90  * Matrix B (and C) has 1 or more registers per row
91  *
    * multiply_accumulate overload enabled when s_C_minor_dim_registers != 0.
    * Accumulates the product of A and B into the registers of C; the values
    * already held by C are an input (they are passed as the addend of every
    * multiply_add).
    *
    * A: left operand, read only.
    * B: right operand, read only.
    * C: result, updated in place.
92  */
93  template<typename dummy = void>
94  RAJA_HOST_DEVICE static RAJA_INLINE
95  typename std::enable_if<(s_C_minor_dim_registers != 0), dummy>::type
97  right_type const& B,
98  result_type& C)
99  {
    // Optional statistics counter; compiled only when vector stats are enabled
    // and __CUDA_ARCH__ is not defined.
100 #if defined(RAJA_ENABLE_VECTOR_STATS) && !defined(__CUDA_ARCH__)
101  RAJA::tensor_stats::num_matrix_mm_multacc_row_row++;
102 #endif
103 
    // Number of registers per row of C; the same stride is used below to
    // index the registers of B.
104  constexpr camp::idx_t num_bc_reg_per_row = s_C_minor_dim_registers;
105 
    // Visit every register of C.
106  RAJA_UNROLL
107  for (camp::idx_t c_reg = 0; c_reg < result_type::s_num_registers; ++c_reg)
108  {
    // Split the flat register index into (row, register-within-row).
109  camp::idx_t bc_col_reg = c_reg % num_bc_reg_per_row;
110  camp::idx_t ac_row = c_reg / num_bc_reg_per_row;
111 
    // Walk the shared dimension (M_SIZE).
112  RAJA_UNROLL
113  for (camp::idx_t a_col = 0; a_col < M_SIZE; ++a_col)
114  {
    // Register of B in row a_col at the same register-within-row offset.
115  camp::idx_t b_reg = a_col * num_bc_reg_per_row + bc_col_reg;
116 
    // Build a register from the scalar A(ac_row, a_col) and combine it with
    // B's register and the current C register via multiply_add.
    // NOTE(review): presumably this is (scalar * B_reg) + C_reg with the
    // scalar broadcast across the register -- confirm against register_type.
117  C.get_register(c_reg) =
118  register_type(A.get(ac_row, a_col))
119  .multiply_add(B.get_register(b_reg), C.get_register(c_reg));
120  }
121  }
122  }
123 
124  /*
125  * Matrix B (and C) have less than one register per row
126  *
    * multiply_accumulate overload enabled when s_C_minor_dim_registers == 0.
    * Several rows share one register, so each row of C is addressed as a
    * (register, segment) pair. The per-row result is added into C with +=,
    * so the values already held by C are preserved.
    *
    * A: left operand, read only.
    * B: right operand, read only.
    * C: result, updated in place.
127  */
128  template<typename dummy = void>
129  RAJA_HOST_DEVICE RAJA_INLINE static
130  typename std::enable_if<(s_C_minor_dim_registers == 0), dummy>::type
132  right_type const& B,
133  result_type& C)
134  {
    // s_segbits comes from the result type; 1 << bc_segbits is used below as
    // the number of A segments addressed per register.
135  constexpr camp::idx_t bc_segbits = result_type::s_segbits;
136  constexpr camp::idx_t a_segments_per_register = 1 << bc_segbits;
137 
    // One iteration per row of A (and of C).
138  RAJA_UNROLL
139  for (camp::idx_t ac_row = 0; ac_row < N_SIZE; ++ac_row)
140  {
    // Locate the C register holding this row and the slot inside it.
141  camp::idx_t c_reg = ac_row / result_type::s_major_dim_per_register;
142  camp::idx_t c_segment = ac_row % result_type::s_major_dim_per_register;
    // Declared without an initializer; the b_reg == 0 branch below assigns it
    // before it is read (this relies on right_type::s_num_registers > 0).
143  register_type c_tmp;
144 
    // One iteration per register of B.
145  RAJA_UNROLL
146  for (camp::idx_t b_reg = 0; b_reg < right_type::s_num_registers; ++b_reg)
147  {
148 
    // Flat segment index into A for this (row, b_reg) pair, then split into
    // the register of A and the segment within that register.
149  camp::idx_t a_segment = ac_row * right_type::s_num_registers + b_reg;
150  camp::idx_t a_reg = a_segment / a_segments_per_register;
151  camp::idx_t a_reg_segment = a_segment % a_segments_per_register;
152 
    // NOTE(review): segmented_broadcast_outer is defined elsewhere; presumably
    // it spreads the selected segment of A so it lines up element-wise with
    // the rows packed in B's register -- confirm.
153  auto a_tmp = A.get_register(a_reg).segmented_broadcast_outer(
154  bc_segbits, a_reg_segment);
155 
    // The first term initializes the accumulator; later terms are folded in
    // with multiply_add.
156  if (b_reg == 0)
157  {
158 
159  c_tmp = a_tmp.multiply(B.get_register(b_reg));
160  }
161  else
162  {
163  c_tmp = a_tmp.multiply_add(B.get_register(b_reg), c_tmp);
164  }
165  }
166 
    // Reduce the partial products with segmented_sum_outer (targeting slot
    // c_segment) and add the result into C's register.
167  C.get_register(c_reg) += c_tmp.segmented_sum_outer(bc_segbits, c_segment);
168  }
169  }
170 
    // Computes C = A * B: overwrites C with result_type(0), then delegates to
    // multiply_accumulate.
    // NOTE(review): listing line 171, directly above this declaration, is
    // absent from this listing -- confirm against the original header.
172  static RAJA_INLINE void multiply(left_type const& A,
173  right_type const& B,
174  result_type& C)
175  {
176  C = result_type(0);
177  multiply_accumulate(A, B, C);
178  }
179 };
180 
// Partial specialization for two RAJA::expt::TensorRegister operands that both
// use ColMajorLayout and share the element type T and REGISTER_POLICY:
//   left  operand: sizes <N_SIZE,  M_SIZE>
//   right operand: sizes <M2_SIZE, O_SIZE>
// It provides multiply_accumulate() (two enable_if-selected overloads) and
// multiply().
//
// NOTE(review): this listing is missing several lines of the original header
// (the embedded line numbers skip 192, 203, 206, 210, 218, 223, 228, 249, 286
// and 338). Per the cross-references at the end of the page, lines 249 and 286
// carry "multiply_accumulate(left_type const &A,". The others (presumably the
// struct head, the self_type alias head and the layout argument of each alias)
// must be confirmed against the original file.
186 template<typename T,
187  typename REGISTER_POLICY,
188  camp::idx_t N_SIZE,
189  camp::idx_t M_SIZE,
190  camp::idx_t M2_SIZE,
191  camp::idx_t O_SIZE>
193  RAJA::expt::TensorRegister<REGISTER_POLICY,
194  T,
195  RAJA::expt::ColMajorLayout,
196  camp::idx_seq<N_SIZE, M_SIZE>>,
197  RAJA::expt::TensorRegister<REGISTER_POLICY,
198  T,
199  RAJA::expt::ColMajorLayout,
200  camp::idx_seq<M2_SIZE, O_SIZE>>>
201 {
202 
    // NOTE(review): listing line 203 is absent; the argument list below
    // presumably belongs to the self_type alias that multiply() uses further
    // down -- confirm against the original header.
204  RAJA::expt::TensorRegister<REGISTER_POLICY,
205  T,
207  camp::idx_seq<N_SIZE, M_SIZE>>,
208  RAJA::expt::TensorRegister<REGISTER_POLICY,
209  T,
211  camp::idx_seq<M2_SIZE, O_SIZE>>>;
212 
    // The second size of the left operand must equal the first size of the
    // right operand; otherwise compilation fails with the message below.
213  static_assert(M_SIZE == M2_SIZE,
214  "Matrices are not compatible for multiplication");
215 
    // Left operand type, sizes <N_SIZE, M_SIZE>.
    // NOTE(review): the layout argument (listing line 218) is absent here.
216  using left_type = RAJA::expt::TensorRegister<REGISTER_POLICY,
217  T,
219  camp::idx_seq<N_SIZE, M_SIZE>>;
220 
    // Right operand type, sizes <M_SIZE, O_SIZE>.
    // NOTE(review): the layout argument (listing line 223) is absent here.
221  using right_type = RAJA::expt::TensorRegister<REGISTER_POLICY,
222  T,
224  camp::idx_seq<M_SIZE, O_SIZE>>;
225 
    // Result type, sizes <N_SIZE, O_SIZE>.
    // NOTE(review): the layout argument (listing line 228) is absent here.
226  using result_type = RAJA::expt::TensorRegister<REGISTER_POLICY,
227  T,
229  camp::idx_seq<N_SIZE, O_SIZE>>;
230 
    // Register type used for temporaries; taken from the result type.
231  using register_type = typename result_type::register_type;
232 
    // Constants forwarded from the operand/result types. Only
    // s_C_minor_dim_registers is read by the code visible in this struct: it
    // selects which multiply_accumulate overload is enabled.
233  static constexpr camp::idx_t s_elements_per_register =
234  left_type::s_elements_per_register;
235  static constexpr camp::idx_t s_A_minor_dim_registers =
236  left_type::s_minor_dim_registers;
237  static constexpr camp::idx_t s_B_minor_dim_registers =
238  right_type::s_minor_dim_registers;
239  static constexpr camp::idx_t s_C_minor_dim_registers =
240  result_type::s_minor_dim_registers;
241 
242  /*
243  * Matrix A (and C) has 1 or more registers per column
244  *
    * multiply_accumulate overload enabled when s_C_minor_dim_registers != 0.
    * Mirror image of the row-major case: here a scalar is read from B and
    * whole registers are read from A. The values already held by C are an
    * input (they are passed as the addend of every multiply_add).
    *
    * A: left operand, read only.
    * B: right operand, read only.
    * C: result, updated in place.
245  */
246  template<typename dummy = void>
247  RAJA_HOST_DEVICE static RAJA_INLINE
248  typename std::enable_if<(s_C_minor_dim_registers != 0), dummy>::type
250  right_type const& B,
251  result_type& C)
252  {
253 
    // Optional statistics counter; compiled only when vector stats are enabled
    // and __CUDA_ARCH__ is not defined.
    // NOTE(review): the counter is named "..._row_row" although this is the
    // column-major specialization -- confirm whether a separate counter was
    // intended.
254 #if defined(RAJA_ENABLE_VECTOR_STATS) && !defined(__CUDA_ARCH__)
255  RAJA::tensor_stats::num_matrix_mm_multacc_row_row++;
256 #endif
257 
258 
    // Number of registers per column of C; the same stride is used below to
    // index the registers of A.
259  constexpr camp::idx_t num_ac_reg_per_col = s_C_minor_dim_registers;
260 
    // Visit every register of C.
261  RAJA_UNROLL
262  for (camp::idx_t c_reg = 0; c_reg < result_type::s_num_registers; ++c_reg)
263  {
    // Split the flat register index into (column, register-within-column).
264  camp::idx_t ac_row_reg = c_reg % num_ac_reg_per_col;
265  camp::idx_t bc_col = c_reg / num_ac_reg_per_col;
266 
    // Walk the shared dimension (M_SIZE).
267  RAJA_UNROLL
268  for (camp::idx_t b_row = 0; b_row < M_SIZE; ++b_row)
269  {
    // Register of A in column b_row at the same register-within-column offset.
270  camp::idx_t a_reg = b_row * num_ac_reg_per_col + ac_row_reg;
271 
    // Build a register from the scalar B(b_row, bc_col) and combine it with
    // A's register and the current C register via multiply_add.
    // NOTE(review): presumably this is (scalar * A_reg) + C_reg with the
    // scalar broadcast across the register -- confirm against register_type.
272  C.get_register(c_reg) =
273  register_type(B.get(b_row, bc_col))
274  .multiply_add(A.get_register(a_reg), C.get_register(c_reg));
275  }
276  }
277  }
278 
279  /*
280  * Matrix A (and C) have less than one register per column
281  *
    * multiply_accumulate overload enabled when s_C_minor_dim_registers == 0.
    * Several columns share one register, so C is walked as (register,
    * segment) pairs while bc_col counts the pairs visited so far. The
    * per-column result is added into C with +=, so the values already held
    * by C are preserved.
    *
    * A: left operand, read only.
    * B: right operand, read only.
    * C: result, updated in place.
282  */
283  template<typename dummy = void>
284  RAJA_HOST_DEVICE RAJA_INLINE static
285  typename std::enable_if<(s_C_minor_dim_registers == 0), dummy>::type
287  right_type const& B,
288  result_type& C)
289  {
    // s_segbits comes from the result type; 1 << ac_segbits is used below as
    // the number of B segments addressed per register.
290  constexpr camp::idx_t ac_segbits = result_type::s_segbits;
291  constexpr camp::idx_t b_segments_per_register = 1 << ac_segbits;
292 
    // Running index, incremented once per (c_reg, c_segment) pair.
293  camp::idx_t bc_col = 0;
294 
    // NOTE(review): the outer bound is derived from N_SIZE while bc_col
    // advances once per inner iteration -- confirm this is correct when
    // N_SIZE != O_SIZE.
295  RAJA_UNROLL
296  for (camp::idx_t c_reg = 0;
297  c_reg < N_SIZE / result_type::s_major_dim_per_register; ++c_reg)
298  {
299 
    // One iteration per slot packed into the current C register.
300  RAJA_UNROLL
301  for (camp::idx_t c_segment = 0;
302  c_segment < result_type::s_major_dim_per_register; ++c_segment)
303  {
304 
    // Declared without an initializer; the a_reg == 0 branch below assigns it
    // before it is read (this relies on right_type::s_num_registers > 0).
305  register_type c_tmp;
306 
    // NOTE(review): a_reg indexes registers of A, but the loop bound is
    // right_type::s_num_registers -- confirm left_type was not intended.
307  RAJA_UNROLL
308  for (camp::idx_t a_reg = 0; a_reg < right_type::s_num_registers;
309  ++a_reg)
310  {
311 
312 
    // Flat segment index into B for this (bc_col, a_reg) pair, then split into
    // the register of B and the segment within that register.
313  camp::idx_t b_segment = bc_col * right_type::s_num_registers + a_reg;
314  camp::idx_t b_reg = b_segment / b_segments_per_register;
315  camp::idx_t b_reg_segment = b_segment % b_segments_per_register;
316 
    // NOTE(review): segmented_broadcast_outer is defined elsewhere; presumably
    // it spreads the selected segment of B so it lines up element-wise with
    // the columns packed in A's register -- confirm.
317  register_type b_tmp = B.get_register(b_reg).segmented_broadcast_outer(
318  ac_segbits, b_reg_segment);
319 
    // The first term initializes the accumulator; later terms are folded in
    // with multiply_add.
320  if (a_reg == 0)
321  {
322  c_tmp = b_tmp.multiply(A.get_register(a_reg));
323  }
324  else
325  {
326  c_tmp = b_tmp.multiply_add(A.get_register(a_reg), c_tmp);
327  }
328  }
329 
    // Reduce the partial products with segmented_sum_outer (targeting slot
    // c_segment) and add the result into C's register.
330  C.get_register(c_reg) +=
331  c_tmp.segmented_sum_outer(ac_segbits, c_segment);
332 
333  ++bc_col;
334  } // c_segment
335  } // c_reg
336  }
337 
    // Computes C = A * B: overwrites C with result_type(0), then delegates to
    // self_type::multiply_accumulate.
    // NOTE(review): listing line 338, directly above this declaration, is
    // absent from this listing -- confirm against the original header.
339  static RAJA_INLINE void multiply(left_type const& A,
340  right_type const& B,
341  result_type& C)
342  {
343  C = result_type(0);
344  self_type::multiply_accumulate(A, B, C);
345  }
346 };
347 
348 
349 } // namespace expt
350 } // namespace internal
351 } // namespace RAJA
352 
353 
354 #endif
RAJA header file defining SIMD/SIMT register operations.
Definition: TensorRegister.hpp:46
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
TensorLayout< 0, 1 > RowMajorLayout
Definition: TensorLayout.hpp:80
TensorLayout< 1, 0 > ColMajorLayout
Definition: TensorLayout.hpp:81
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE static RAJA_INLINE std::enable_if<(s_C_minor_dim_registers==0), dummy >::type multiply_accumulate(left_type const &A, right_type const &B, result_type &C)
Definition: MatrixMatrixMultiply.hpp:286
static RAJA_HOST_DEVICE RAJA_INLINE std::enable_if<(s_C_minor_dim_registers !=0), dummy >::type multiply_accumulate(left_type const &A, right_type const &B, result_type &C)
Definition: MatrixMatrixMultiply.hpp:249
RAJA_HOST_DEVICE static RAJA_INLINE std::enable_if<(s_C_minor_dim_registers==0), dummy >::type multiply_accumulate(left_type const &A, right_type const &B, result_type &C)
Definition: MatrixMatrixMultiply.hpp:131
static RAJA_HOST_DEVICE RAJA_INLINE std::enable_if<(s_C_minor_dim_registers !=0), dummy >::type multiply_accumulate(left_type const &A, right_type const &B, result_type &C)
Definition: MatrixMatrixMultiply.hpp:96
Definition: MatrixMatrixMultiply.hpp:36