20 #ifndef RAJA_pattern_tensor_internal_MatrixMatrixMultiply_HPP
21 #define RAJA_pattern_tensor_internal_MatrixMatrixMultiply_HPP
23 #include "camp/camp.hpp"
24 #include "RAJA/config.hpp"
35 template<
typename MATA,
typename MATB>
44 typename REGISTER_POLICY,
52 RAJA::expt::RowMajorLayout,
53 camp::idx_seq<N_SIZE, M_SIZE>>,
56 RAJA::expt::RowMajorLayout,
57 camp::idx_seq<M2_SIZE, O_SIZE>>>
60 static_assert(M_SIZE == M2_SIZE,
61 "Matrices are not compatible for multiplication");
66 camp::idx_seq<N_SIZE, M_SIZE>>;
71 camp::idx_seq<M_SIZE, O_SIZE>>;
76 camp::idx_seq<N_SIZE, O_SIZE>>;
80 static constexpr camp::idx_t s_elements_per_register =
81 left_type::s_elements_per_register;
82 static constexpr camp::idx_t s_A_minor_dim_registers =
83 left_type::s_minor_dim_registers;
84 static constexpr camp::idx_t s_B_minor_dim_registers =
85 right_type::s_minor_dim_registers;
86 static constexpr camp::idx_t s_C_minor_dim_registers =
87 result_type::s_minor_dim_registers;
93 template<
typename dummy =
void>
95 typename std::enable_if<(s_C_minor_dim_registers != 0), dummy>::type
100 #if defined(RAJA_ENABLE_VECTOR_STATS) && !defined(__CUDA_ARCH__)
101 RAJA::tensor_stats::num_matrix_mm_multacc_row_row++;
104 constexpr camp::idx_t num_bc_reg_per_row = s_C_minor_dim_registers;
107 for (camp::idx_t c_reg = 0; c_reg < result_type::s_num_registers; ++c_reg)
109 camp::idx_t bc_col_reg = c_reg % num_bc_reg_per_row;
110 camp::idx_t ac_row = c_reg / num_bc_reg_per_row;
113 for (camp::idx_t a_col = 0; a_col < M_SIZE; ++a_col)
115 camp::idx_t b_reg = a_col * num_bc_reg_per_row + bc_col_reg;
117 C.get_register(c_reg) =
119 .multiply_add(B.get_register(b_reg), C.get_register(c_reg));
128 template<
typename dummy =
void>
130 typename std::enable_if<(s_C_minor_dim_registers == 0), dummy>::type
135 constexpr camp::idx_t bc_segbits = result_type::s_segbits;
136 constexpr camp::idx_t a_segments_per_register = 1 << bc_segbits;
139 for (camp::idx_t ac_row = 0; ac_row < N_SIZE; ++ac_row)
141 camp::idx_t c_reg = ac_row / result_type::s_major_dim_per_register;
142 camp::idx_t c_segment = ac_row % result_type::s_major_dim_per_register;
146 for (camp::idx_t b_reg = 0; b_reg < right_type::s_num_registers; ++b_reg)
149 camp::idx_t a_segment = ac_row * right_type::s_num_registers + b_reg;
150 camp::idx_t a_reg = a_segment / a_segments_per_register;
151 camp::idx_t a_reg_segment = a_segment % a_segments_per_register;
153 auto a_tmp = A.get_register(a_reg).segmented_broadcast_outer(
154 bc_segbits, a_reg_segment);
159 c_tmp = a_tmp.multiply(B.get_register(b_reg));
163 c_tmp = a_tmp.multiply_add(B.get_register(b_reg), c_tmp);
167 C.get_register(c_reg) += c_tmp.segmented_sum_outer(bc_segbits, c_segment);
177 multiply_accumulate(A, B, C);
187 typename REGISTER_POLICY,
195 RAJA::expt::ColMajorLayout,
196 camp::idx_seq<N_SIZE, M_SIZE>>,
199 RAJA::expt::ColMajorLayout,
200 camp::idx_seq<M2_SIZE, O_SIZE>>>
207 camp::idx_seq<N_SIZE, M_SIZE>>,
211 camp::idx_seq<M2_SIZE, O_SIZE>>>;
213 static_assert(M_SIZE == M2_SIZE,
214 "Matrices are not compatible for multiplication");
219 camp::idx_seq<N_SIZE, M_SIZE>>;
224 camp::idx_seq<M_SIZE, O_SIZE>>;
229 camp::idx_seq<N_SIZE, O_SIZE>>;
233 static constexpr camp::idx_t s_elements_per_register =
234 left_type::s_elements_per_register;
235 static constexpr camp::idx_t s_A_minor_dim_registers =
236 left_type::s_minor_dim_registers;
237 static constexpr camp::idx_t s_B_minor_dim_registers =
238 right_type::s_minor_dim_registers;
239 static constexpr camp::idx_t s_C_minor_dim_registers =
240 result_type::s_minor_dim_registers;
246 template<
typename dummy =
void>
248 typename std::enable_if<(s_C_minor_dim_registers != 0), dummy>::type
254 #if defined(RAJA_ENABLE_VECTOR_STATS) && !defined(__CUDA_ARCH__)
255 RAJA::tensor_stats::num_matrix_mm_multacc_row_row++;
259 constexpr camp::idx_t num_ac_reg_per_col = s_C_minor_dim_registers;
262 for (camp::idx_t c_reg = 0; c_reg < result_type::s_num_registers; ++c_reg)
264 camp::idx_t ac_row_reg = c_reg % num_ac_reg_per_col;
265 camp::idx_t bc_col = c_reg / num_ac_reg_per_col;
268 for (camp::idx_t b_row = 0; b_row < M_SIZE; ++b_row)
270 camp::idx_t a_reg = b_row * num_ac_reg_per_col + ac_row_reg;
272 C.get_register(c_reg) =
274 .multiply_add(A.get_register(a_reg), C.get_register(c_reg));
283 template<
typename dummy =
void>
285 typename std::enable_if<(s_C_minor_dim_registers == 0), dummy>::type
290 constexpr camp::idx_t ac_segbits = result_type::s_segbits;
291 constexpr camp::idx_t b_segments_per_register = 1 << ac_segbits;
293 camp::idx_t bc_col = 0;
296 for (camp::idx_t c_reg = 0;
297 c_reg < N_SIZE / result_type::s_major_dim_per_register; ++c_reg)
301 for (camp::idx_t c_segment = 0;
302 c_segment < result_type::s_major_dim_per_register; ++c_segment)
308 for (camp::idx_t a_reg = 0; a_reg < right_type::s_num_registers;
313 camp::idx_t b_segment = bc_col * right_type::s_num_registers + a_reg;
314 camp::idx_t b_reg = b_segment / b_segments_per_register;
315 camp::idx_t b_reg_segment = b_segment % b_segments_per_register;
317 register_type b_tmp = B.get_register(b_reg).segmented_broadcast_outer(
318 ac_segbits, b_reg_segment);
322 c_tmp = b_tmp.multiply(A.get_register(a_reg));
326 c_tmp = b_tmp.multiply_add(A.get_register(a_reg), c_tmp);
330 C.get_register(c_reg) +=
331 c_tmp.segmented_sum_outer(ac_segbits, c_segment);
344 self_type::multiply_accumulate(A, B, C);
RAJA header file defining SIMD/SIMT register operations.
Definition: TensorRegister.hpp:46
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
TensorLayout< 0, 1 > RowMajorLayout
Definition: TensorLayout.hpp:80
TensorLayout< 1, 0 > ColMajorLayout
Definition: TensorLayout.hpp:81
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE static RAJA_INLINE std::enable_if<(s_C_minor_dim_registers==0), dummy >::type multiply_accumulate(left_type const &A, right_type const &B, result_type &C)
Definition: MatrixMatrixMultiply.hpp:286
static RAJA_HOST_DEVICE RAJA_INLINE void multiply(left_type const &A, right_type const &B, result_type &C)
Definition: MatrixMatrixMultiply.hpp:339
typename result_type::register_type register_type
Definition: MatrixMatrixMultiply.hpp:231
static RAJA_HOST_DEVICE RAJA_INLINE std::enable_if<(s_C_minor_dim_registers !=0), dummy >::type multiply_accumulate(left_type const &A, right_type const &B, result_type &C)
Definition: MatrixMatrixMultiply.hpp:249
static RAJA_HOST_DEVICE RAJA_INLINE void multiply(left_type const &A, right_type const &B, result_type &C)
Definition: MatrixMatrixMultiply.hpp:172
RAJA_HOST_DEVICE static RAJA_INLINE std::enable_if<(s_C_minor_dim_registers==0), dummy >::type multiply_accumulate(left_type const &A, right_type const &B, result_type &C)
Definition: MatrixMatrixMultiply.hpp:131
static RAJA_HOST_DEVICE RAJA_INLINE std::enable_if<(s_C_minor_dim_registers !=0), dummy >::type multiply_accumulate(left_type const &A, right_type const &B, result_type &C)
Definition: MatrixMatrixMultiply.hpp:96
typename result_type::register_type register_type
Definition: MatrixMatrixMultiply.hpp:78
Definition: MatrixMatrixMultiply.hpp:36