20 #ifndef RAJA_pattern_tensor_MatrixRegisterImpl_HPP
21 #define RAJA_pattern_tensor_MatrixRegisterImpl_HPP
23 #include "camp/camp.hpp"
24 #include "RAJA/config.hpp"
37 template<
typename REGISTER_POLICY,
46 camp::idx_seq<ROW_SIZE, COL_SIZE>>
48 TensorRegister<REGISTER_POLICY,
50 TensorLayout<ROW_ORD, COL_ORD>,
51 camp::idx_seq<ROW_SIZE, COL_SIZE>>>
57 camp::idx_seq<ROW_SIZE, COL_SIZE>>;
62 camp::idx_seq<ROW_SIZE, COL_SIZE>>>;
74 camp::idx_seq<ROW_SIZE, COL_SIZE>>;
79 camp::idx_seq<COL_SIZE, ROW_SIZE>>;
83 camp::idx_seq<ROW_SIZE, ROW_SIZE>>;
85 static constexpr camp::idx_t s_num_rows = ROW_SIZE;
86 static constexpr camp::idx_t s_num_columns = COL_SIZE;
89 static constexpr camp::idx_t s_elements_per_register =
93 static constexpr camp::idx_t s_num_registers =
94 (ROW_SIZE * COL_SIZE) / s_elements_per_register;
97 static_assert((ROW_SIZE * COL_SIZE) ==
98 s_num_registers * s_elements_per_register,
99 "MatrixRegister must be dimensioned to exactly fit an integer "
100 "number of registers");
104 static constexpr camp::idx_t s_shift_per_register = log_base2_t::value;
106 static constexpr camp::idx_t s_mask_per_register =
107 (1 << log_base2_t::value) - 1;
110 static constexpr camp::idx_t s_minor_dim_elements =
111 layout_type::is_row_major() ? s_num_columns : s_num_rows;
113 static constexpr camp::idx_t s_major_dim_elements =
114 layout_type::is_row_major() ? s_num_rows : s_num_columns;
119 static constexpr camp::idx_t s_minor_dim_registers =
120 s_minor_dim_elements / s_elements_per_register;
122 static_assert(s_minor_dim_registers > 0 || log_base2_t::is_exact,
123 "Minor dimension smaller than a vector need to be a power of "
126 static_assert(s_minor_dim_registers == 0 ||
127 (s_minor_dim_elements % s_elements_per_register == 0),
128 "Minor dimensions greater than a vector length must be an "
129 "integer number of vectors");
132 static constexpr camp::idx_t s_major_dim_per_register =
133 s_elements_per_register / s_minor_dim_elements;
135 static constexpr camp::idx_t s_segbits =
139 template<
typename IDX>
143 return layout_type::is_row_major()
144 ? (row * IDX(COL_SIZE) + col) >> IDX(s_shift_per_register)
145 : (col * IDX(ROW_SIZE) + row) >> IDX(s_shift_per_register);
148 template<
typename IDX>
149 RAJA_INLINE
RAJA_HOST_DEVICE constexpr
static auto to_lane(IDX row, IDX col)
152 return layout_type::is_row_major()
153 ? (row * IDX(COL_SIZE) + col) & IDX(s_mask_per_register)
154 : (col * IDX(ROW_SIZE) + row) & IDX(s_mask_per_register);
157 using base_type::m_registers;
187 template<camp::
idx_t STRIDE_ONE_DIM>
190 return (STRIDE_ONE_DIM == 0 && layout_type::is_column_major()) ||
191 (STRIDE_ONE_DIM == 1 && layout_type::is_row_major());
202 return dim == 0 ? ROW_SIZE : COL_SIZE;
214 this->broadcast(value);
226 template<
typename T2,
typename L,
typename RP>
229 return matrix_multiply(
y);
236 template<
typename T2,
typename RP>
239 return right_multiply_vector(
y);
243 template<
typename REF_TYPE>
246 template<
typename REF_TYPE>
249 RefBridge<REF_TYPE>::load_ref(*
this, ref);
253 template<
typename REF_TYPE>
256 RefBridge<REF_TYPE>::store_ref(*
this, ref);
260 template<
typename POINTER_TYPE,
263 camp::idx_t STRIDE_ONE_DIM>
265 RAJA::internal::expt::
266 TensorRef<POINTER_TYPE, INDEX_TYPE, TENSOR_SIZE, 2, STRIDE_ONE_DIM>>
270 TensorRef<POINTER_TYPE, INDEX_TYPE, TENSOR_SIZE, 2, STRIDE_ONE_DIM>;
285 if (
self.is_ref_packed<STRIDE_ONE_DIM>())
329 if (
self.is_ref_packed<STRIDE_ONE_DIM>())
361 template<
typename POINTER_TYPE,
364 INDEX_TYPE StrideInt1,
365 INDEX_TYPE StrideInt2,
366 INDEX_TYPE BeginInt1,
367 INDEX_TYPE BeginInt2,
370 camp::idx_t STRIDE_ONE_DIM>
375 camp::int_seq<INDEX_TYPE, StrideInt1, StrideInt2>,
376 camp::int_seq<INDEX_TYPE, BeginInt1, BeginInt2>,
377 camp::int_seq<INDEX_TYPE, SizeInt1, SizeInt2>,
385 camp::int_seq<INDEX_TYPE, StrideInt1, StrideInt2>,
386 camp::int_seq<INDEX_TYPE, BeginInt1, BeginInt2>,
387 camp::int_seq<INDEX_TYPE, SizeInt1, SizeInt2>,
399 auto ptr = ref.m_pointer + ref.m_tile.m_begin[0] * ref.m_stride[0] +
400 ref.m_tile.m_begin[1] * ref.m_stride[1];
403 if (
self.is_ref_packed<STRIDE_ONE_DIM>())
408 self.load_packed(ptr, ref.m_stride[0], ref.m_stride[1]);
413 self.load_packed_nm(ptr, ref.m_stride[0], ref.m_stride[1],
414 ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
423 self.load_strided(ptr, ref.m_stride[0], ref.m_stride[1]);
428 self.load_strided_nm(ptr, ref.m_stride[0], ref.m_stride[1],
429 ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
443 auto ptr = ref.m_pointer + ref.m_tile.m_begin[0] * ref.m_stride[0] +
444 ref.m_tile.m_begin[1] * ref.m_stride[1];
447 if (
self.is_ref_packed<STRIDE_ONE_DIM>())
452 self.store_packed(ptr, ref.m_stride[0], ref.m_stride[1]);
457 self.store_packed_nm(ptr, ref.m_stride[0], ref.m_stride[1],
458 ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
467 self.store_strided(ptr, ref.m_stride[0], ref.m_stride[1]);
472 self.store_strided_nm(ptr, ref.m_stride[0], ref.m_stride[1],
473 ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
496 if ((layout_type::is_row_major() && (row_stride == COL_SIZE)) ||
497 (layout_type::is_column_major() && (col_stride == ROW_SIZE)))
500 for (camp::idx_t reg = 0; reg < s_num_registers; ++reg)
502 m_registers[reg].
load_packed(ptr + reg * s_elements_per_register);
506 else if (layout_type::is_row_major())
510 if (s_minor_dim_registers)
513 for (camp::idx_t row = 0; row < ROW_SIZE; ++row)
515 for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
519 row * row_stride + colreg * s_elements_per_register;
531 return load_strided(ptr, row_stride, col_stride);
538 if (s_minor_dim_registers)
542 for (camp::idx_t col = 0; col < COL_SIZE; ++col)
544 for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
548 col * col_stride + rowreg * s_elements_per_register;
550 m_registers[reg].load_packed(ptr + offset);
560 return load_strided(ptr, row_stride, col_stride);
578 if (layout_type::is_row_major())
581 if (s_minor_dim_registers)
583 for (camp::idx_t i = 0; i < s_num_registers; ++i)
586 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
588 s_elements_per_register * (i - (row * s_minor_dim_registers));
589 m_registers[i].load_strided(ptr + row * row_stride + col * col_stride,
596 for (camp::idx_t i = 0; i < s_num_registers; ++i)
599 ptr + i * row_stride * s_major_dim_per_register;
600 m_registers[i].segmented_load(ptr_i, s_segbits, col_stride,
611 if (s_minor_dim_registers)
613 for (camp::idx_t i = 0; i < s_num_registers; ++i)
616 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
618 s_elements_per_register * (i - (col * s_minor_dim_registers));
620 m_registers[i].load_strided(ptr + row * row_stride + col * col_stride,
627 for (camp::idx_t i = 0; i < s_num_registers; ++i)
630 ptr + i * col_stride * s_major_dim_per_register;
631 m_registers[i].segmented_load(ptr_i, s_segbits, row_stride,
653 if (layout_type::is_row_major())
657 if (s_minor_dim_registers)
660 for (camp::idx_t row = 0; row < num_rows; ++row)
662 for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
665 camp::idx_t reg = row * s_minor_dim_registers + colreg;
667 camp::idx_t col0 = colreg * s_elements_per_register;
668 camp::idx_t offset = row * row_stride + col0;
671 if (col0 + s_elements_per_register <= num_cols)
679 m_registers[reg].load_packed_n(ptr + offset, num_cols - col0);
682 for (camp::idx_t i = colreg + 1; i < s_minor_dim_registers; ++i)
694 for (camp::idx_t row = num_rows; row < ROW_SIZE; ++row)
696 for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
699 camp::idx_t reg = row * s_minor_dim_registers + colreg;
709 return load_strided_nm(ptr, row_stride, col_stride, num_rows, num_cols);
717 if (s_minor_dim_registers)
720 for (camp::idx_t col = 0; col < num_cols; ++col)
722 for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
725 camp::idx_t reg = col * s_minor_dim_registers + rowreg;
727 camp::idx_t row0 = rowreg * s_elements_per_register;
728 camp::idx_t offset = col * col_stride + row0;
731 if (row0 + s_elements_per_register <= num_rows)
733 m_registers[reg].load_packed(ptr + offset);
739 m_registers[reg].load_packed_n(ptr + offset, num_rows - row0);
742 for (camp::idx_t i = rowreg + 1; i < s_minor_dim_registers; ++i)
753 for (camp::idx_t col = num_cols; col < COL_SIZE; ++col)
755 for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
758 camp::idx_t reg = col * s_minor_dim_registers + rowreg;
769 return load_strided_nm(ptr, row_stride, col_stride, num_rows, num_cols);
789 if (layout_type::is_row_major())
792 if (s_minor_dim_registers)
795 for (camp::idx_t i = 0; i < s_num_registers; ++i)
798 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
806 s_elements_per_register * (i - (row * s_minor_dim_registers));
809 camp::idx_t reg_num_cols = s_elements_per_register;
810 if (reg_num_cols + col > num_cols)
812 reg_num_cols = num_cols - col;
813 m_registers[i].load_strided_n(ptr + row * row_stride +
815 col_stride, reg_num_cols);
819 m_registers[i].load_strided(
820 ptr + row * row_stride + col * col_stride, col_stride);
829 for (camp::idx_t i = 0; i < s_num_registers; ++i)
832 camp::idx_t reg_num_rows = num_rows - i * s_major_dim_per_register;
833 reg_num_rows = reg_num_rows > s_major_dim_per_register
834 ? s_major_dim_per_register
838 ptr + i * row_stride * s_major_dim_per_register;
839 m_registers[i].segmented_load_nm(ptr_i, s_segbits, col_stride,
840 row_stride, num_cols, reg_num_rows);
850 if (s_minor_dim_registers)
852 for (camp::idx_t i = 0; i < s_num_registers; ++i)
855 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
863 s_elements_per_register * (i - (col * s_minor_dim_registers));
865 camp::idx_t reg_num_rows = s_elements_per_register;
866 if (reg_num_rows + row > num_rows)
868 reg_num_rows = num_rows - row;
869 m_registers[i].load_strided_n(ptr + row * row_stride +
871 row_stride, reg_num_rows);
875 m_registers[i].load_strided(
876 ptr + row * row_stride + col * col_stride, row_stride);
884 for (camp::idx_t i = 0; i < s_num_registers; ++i)
887 camp::idx_t reg_num_cols = num_cols - i * s_major_dim_per_register;
888 reg_num_cols = reg_num_cols > s_major_dim_per_register
889 ? s_major_dim_per_register
893 ptr + i * col_stride * s_major_dim_per_register;
894 m_registers[i].segmented_load_nm(ptr_i, s_segbits, row_stride,
895 col_stride, num_rows, reg_num_cols);
914 int col_stride)
const
918 if ((layout_type::is_row_major() && (row_stride == COL_SIZE)) ||
919 (layout_type::is_column_major() && (col_stride == ROW_SIZE)))
922 for (camp::idx_t reg = 0; reg < s_num_registers; ++reg)
924 m_registers[reg].
store_packed(ptr + reg * s_elements_per_register);
928 else if (layout_type::is_row_major())
932 if (s_minor_dim_registers)
934 for (camp::idx_t i = 0; i < s_num_registers; ++i)
937 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
939 s_elements_per_register * (i - (row * s_minor_dim_registers));
940 m_registers[i].store_packed(ptr + row * row_stride +
947 store_strided(ptr, row_stride, col_stride);
954 if (s_minor_dim_registers)
956 for (camp::idx_t i = 0; i < s_num_registers; ++i)
959 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
961 s_elements_per_register * (i - (col * s_minor_dim_registers));
962 m_registers[i].store_packed(ptr + row * row_stride +
969 store_strided(ptr, row_stride, col_stride);
986 int col_stride)
const
990 if (layout_type::is_row_major())
993 if (s_minor_dim_registers)
995 for (camp::idx_t i = 0; i < s_num_registers; ++i)
998 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
1000 s_elements_per_register * (i - (row * s_minor_dim_registers));
1001 m_registers[i].store_strided(
1002 ptr + row * row_stride + col * col_stride, col_stride);
1008 for (camp::idx_t i = 0; i < s_num_registers; ++i)
1010 element_type* ptr_i = ptr + i * row_stride * s_major_dim_per_register;
1011 m_registers[i].segmented_store(ptr_i, s_segbits, col_stride,
1021 if (s_minor_dim_registers)
1023 for (camp::idx_t i = 0; i < s_num_registers; ++i)
1026 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
1028 s_elements_per_register * (i - (col * s_minor_dim_registers));
1029 m_registers[i].store_strided(
1030 ptr + row * row_stride + col * col_stride, row_stride);
1036 for (camp::idx_t i = 0; i < s_num_registers; ++i)
1038 element_type* ptr_i = ptr + i * col_stride * s_major_dim_per_register;
1039 m_registers[i].segmented_store(ptr_i, s_segbits, row_stride,
1062 if (layout_type::is_row_major())
1066 if (s_minor_dim_registers)
1069 for (camp::idx_t row = 0; row < num_rows; ++row)
1071 for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
1074 camp::idx_t reg = row * s_minor_dim_registers + colreg;
1076 camp::idx_t col0 = colreg * s_elements_per_register;
1077 camp::idx_t offset = row * row_stride + col0;
1080 if (col0 + s_elements_per_register <= num_cols)
1088 m_registers[reg].store_packed_n(ptr + offset, num_cols - col0);
1099 return store_strided_nm(ptr, row_stride, col_stride, num_rows,
1108 if (s_minor_dim_registers)
1111 for (camp::idx_t col = 0; col < num_cols; ++col)
1113 for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
1116 camp::idx_t reg = col * s_minor_dim_registers + rowreg;
1118 camp::idx_t row0 = rowreg * s_elements_per_register;
1119 camp::idx_t offset = col * col_stride + row0;
1122 if (row0 + s_elements_per_register <= num_rows)
1124 m_registers[reg].store_packed(ptr + offset);
1130 m_registers[reg].store_packed_n(ptr + offset, num_rows - row0);
1142 return store_strided_nm(ptr, row_stride, col_stride, num_rows,
1164 if (layout_type::is_row_major())
1167 if (s_minor_dim_registers)
1170 for (camp::idx_t i = 0; i < s_num_registers; ++i)
1173 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
1177 s_elements_per_register * (i - (row * s_minor_dim_registers));
1180 camp::idx_t reg_num_cols = s_elements_per_register;
1181 if (reg_num_cols + col > num_cols)
1183 reg_num_cols = num_cols - col;
1184 m_registers[i].store_strided_n(ptr + row * row_stride +
1186 col_stride, reg_num_cols);
1190 m_registers[i].store_strided(
1191 ptr + row * row_stride + col * col_stride, col_stride);
1200 for (camp::idx_t i = 0; i < s_num_registers; ++i)
1203 camp::idx_t reg_num_rows = num_rows - i * s_major_dim_per_register;
1204 reg_num_rows = reg_num_rows > s_major_dim_per_register
1205 ? s_major_dim_per_register
1208 element_type* ptr_i = ptr + i * row_stride * s_major_dim_per_register;
1209 m_registers[i].segmented_store_nm(ptr_i, s_segbits, col_stride,
1210 row_stride, num_cols, reg_num_rows);
1220 if (s_minor_dim_registers)
1222 for (camp::idx_t i = 0; i < s_num_registers; ++i)
1225 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
1229 s_elements_per_register * (i - (col * s_minor_dim_registers));
1231 camp::idx_t reg_num_rows = s_elements_per_register;
1232 if (reg_num_rows + row > num_rows)
1234 reg_num_rows = num_rows - row;
1235 m_registers[i].store_strided_n(ptr + row * row_stride +
1237 row_stride, reg_num_rows);
1241 m_registers[i].store_strided(
1242 ptr + row * row_stride + col * col_stride, row_stride);
1250 for (camp::idx_t i = 0; i < s_num_registers; ++i)
1253 camp::idx_t reg_num_cols = num_cols - i * s_major_dim_per_register;
1254 reg_num_cols = reg_num_cols > s_major_dim_per_register
1255 ? s_major_dim_per_register
1258 element_type* ptr_i = ptr + i * col_stride * s_major_dim_per_register;
1259 m_registers[i].segmented_store_nm(ptr_i, s_segbits, row_stride,
1260 col_stride, num_rows, reg_num_cols);
1277 if (layout_type::is_row_major())
1280 if (s_minor_dim_registers)
1283 for (camp::idx_t i = 0; i < s_num_registers; ++i)
1286 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
1290 s_elements_per_register * (i - (row * s_minor_dim_registers));
1293 camp::idx_t reg_num_cols = s_elements_per_register;
1294 if (reg_num_cols + col > num_cols)
1296 reg_num_cols = num_cols - col;
1297 result.m_registers[i] =
1298 m_registers[i].divide_n(mat.m_registers[i], reg_num_cols);
1302 result.m_registers[i] = m_registers[i].divide(mat.m_registers[i]);
1311 for (camp::idx_t i = 0; i < s_num_registers; ++i)
1314 camp::idx_t reg_num_rows = num_rows - i * s_major_dim_per_register;
1315 reg_num_rows = reg_num_rows > s_major_dim_per_register
1316 ? s_major_dim_per_register
1319 result.m_registers[i] = m_registers[i].segmented_divide_nm(
1320 mat.m_registers[i], s_segbits, num_cols, reg_num_rows);
1330 if (s_minor_dim_registers)
1332 for (camp::idx_t i = 0; i < s_num_registers; ++i)
1335 i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
1339 s_elements_per_register * (i - (col * s_minor_dim_registers));
1341 camp::idx_t reg_num_rows = s_elements_per_register;
1342 if (reg_num_rows + row > num_rows)
1344 reg_num_rows = num_rows - row;
1345 result.m_registers[i] =
1346 m_registers[i].divide_n(mat.m_registers[i], reg_num_rows);
1350 result.m_registers[i] = m_registers[i].divide(mat.m_registers[i]);
1358 for (camp::idx_t i = 0; i < s_num_registers; ++i)
1361 camp::idx_t reg_num_cols = num_cols - i * s_major_dim_per_register;
1362 reg_num_cols = reg_num_cols > s_major_dim_per_register
1363 ? s_major_dim_per_register
1366 result.m_registers[i] = m_registers[i].segmented_divide_nm(
1367 mat.m_registers[i], s_segbits, num_rows, reg_num_cols);
1384 transpose_type transpose()
const {
1386 static constexpr camp::idx_t num_elem = register_type::s_num_elem;
1396 self_type result = *
this;
1398 if(s_minor_dim_registers == 0){
1422 for(camp::idx_t lvl = 0; (1<<lvl) < num_elem;++ lvl){
1426 camp::idx_t skip_bits = 0;
1427 if(transpose_type::s_major_dim_per_register <= 1){
1430 camp::idx_t skip_reg = (1<<skip_bits)*s_minor_dim_registers;
1432 auto const &vals = result.m_registers;
1435 for(camp::idx_t major = 0;major < s_major_dim_elements;++ major){
1436 if(((major>>skip_bits)&0x1) == 0){
1437 for(camp::idx_t i = major*s_minor_dim_registers;i < (major+1)*s_minor_dim_registers;++ i){
1438 tmp.m_registers[i] = vals[i].transpose_shuffle_left(lvl, vals[i+skip_reg]);
1443 for(camp::idx_t i = major*s_minor_dim_registers;i < (major+1)*s_minor_dim_registers;++ i){
1445 tmp.m_registers[i] = vals[i-skip_reg].transpose_shuffle_right(lvl, vals[i]);
1456 for(camp::idx_t lvl = 0; (1<<lvl) < s_minor_dim_registers;++ lvl){
1459 camp::idx_t skip_reg = 1<<lvl;
1461 auto const &vals = result.m_registers;
1464 for(camp::idx_t major = 0;major < s_major_dim_elements;++ major){
1465 if(((major>>skip_bits)&0x1) == 0){
1466 for(camp::idx_t minor = 0;minor < self_type::s_minor_dim_registers;++ minor){
1469 camp::idx_t xy_select = (minor >> lvl) & 0x1;
1471 camp::idx_t reg = major*s_minor_dim_registers + minor;
1472 camp::idx_t reg_x = major*s_minor_dim_registers + minor;
1473 camp::idx_t reg_y = (major+skip_reg)*s_minor_dim_registers + minor;
1476 tmp.m_registers[reg] =
1477 xy_select == 0 ? result.m_registers[reg_x] : result.m_registers[reg_y];
1491 transpose_type *tptr =
reinterpret_cast<transpose_type*
>(&result);
1507 void inplace_transpose() {
1508 *
this = transpose();
1521 transpose_tensor_type
const &transpose_by_type()
const {
1522 return reinterpret_cast<transpose_tensor_type
const &
>(*this);
1534 return right_multiply_vector_accumulate(v, result);
1546 return left_multiply_vector_accumulate(v, result);
1563 if (layout_type::is_row_major())
1567 if (s_minor_dim_registers == 0)
1572 auto vv = v.get_register(0).segmented_broadcast_inner(s_segbits, 0);
1577 for (camp::idx_t outseg = 0; outseg < s_num_registers; ++outseg)
1581 camp::idx_t result_reg = outseg >> s_segbits;
1584 camp::idx_t result_seg = outseg - (result_reg << s_segbits);
1588 m_registers[outseg].segmented_dot(s_segbits, result_seg, vv);
1591 result.get_register(result_reg) += value;
1599 camp::idx_t reg = 0;
1601 for (camp::idx_t row = 0; row < s_num_rows; ++row)
1607 for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
1611 m_registers[reg].multiply_add(v.get_register(colreg), rowsum);
1617 auto value = result.get(row) + rowsum.sum();
1618 result.set(value, row);
1628 if (s_minor_dim_registers == 0)
1631 auto& mv = result.get_register(0);
1635 for (camp::idx_t m_reg = 0; m_reg < s_num_registers; ++m_reg)
1637 camp::idx_t v_reg = m_reg >> s_segbits;
1638 camp::idx_t v_seg = m_reg & ((1 << s_segbits) - 1);
1641 v.get_register(v_reg).segmented_broadcast_outer(s_segbits, v_seg);
1642 mv = m_registers[m_reg].multiply_add(v_tmp, mv);
1646 mv = mv.segmented_sum_outer(s_segbits, 0);
1653 camp::idx_t reg = 0;
1655 for (camp::idx_t col = 0; col < s_num_columns; ++col)
1663 for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
1666 auto& mv = result.get_register(rowreg);
1667 mv = m_registers[reg].multiply_add(v_col, mv);
1691 if (layout_type::is_row_major())
1695 if (s_minor_dim_registers == 0)
1697 auto& vm = result.get_register(0);
1701 for (camp::idx_t m_reg = 0; m_reg < s_num_registers; ++m_reg)
1703 camp::idx_t v_reg = m_reg >> s_segbits;
1704 camp::idx_t v_seg = m_reg & ((1 << s_segbits) - 1);
1707 v.get_register(v_reg).segmented_broadcast_outer(s_segbits, v_seg);
1708 vm = m_registers[m_reg].multiply_add(v_tmp, vm);
1712 vm = vm.segmented_sum_outer(s_segbits, 0);
1719 camp::idx_t reg = 0;
1721 for (camp::idx_t row = 0; row < s_num_rows; ++row)
1725 for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
1728 result.get_register(colreg) = m_registers[reg].multiply_add(
1729 lhs_bcat, result.get_register(colreg));
1743 if (s_minor_dim_registers == 0)
1748 auto vv = v.get_register(0).segmented_broadcast_inner(s_segbits, 0);
1753 for (camp::idx_t outseg = 0; outseg < s_num_registers; ++outseg)
1757 camp::idx_t result_reg = outseg >> s_segbits;
1760 camp::idx_t result_seg = outseg - (result_reg << s_segbits);
1764 m_registers[outseg].segmented_dot(s_segbits, result_seg, vv);
1767 result.get_register(result_reg) += value;
1774 camp::idx_t reg = 0;
1776 for (camp::idx_t col = 0; col < s_num_columns; ++col)
1782 for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
1785 m_registers[reg].multiply_add(v.get_register(rowreg), colsum);
1791 auto value = result.get(col) + colsum.sum();
1792 result.set(value, col);
1805 template<
typename RMAT>
1807 MatrixMatrixMultiplyHelper<self_type, RMAT>::result_type
1820 template<
typename RMAT>
1822 MatrixMatrixMultiplyHelper<self_type, RMAT>::result_type
1827 RMAT>::result_type
const& C)
const
1832 self_type, RMAT>::multiply_accumulate(*
this, B, res);
1839 template<
typename ACCMAT,
typename RMAT>
1842 RMAT
const& B)
const
1845 self_type, RMAT>::multiply_accumulate(*
this, B, acc);
1853 m_registers[to_register(row, col)].
set(val, to_lane(row, col));
1862 return m_registers[to_register(row, col)].get(to_lane(row, col));
1869 camp::idx_t segbits,
1870 camp::idx_t segment)
const
1875 camp::idx_t num_rows = register_type::s_num_elem >> segbits;
1876 camp::idx_t num_repeats = 1 << segbits;
1878 camp::idx_t col0 = (starting_column + num_rows * segment) % s_num_columns;
1879 camp::idx_t row0 = num_rows * segment;
1881 for (camp::idx_t i = 0; i < num_rows; ++i)
1883 camp::idx_t col = (col0 + i) % s_num_columns;
1884 camp::idx_t row = row0 + i;
1885 auto value =
get(row, col);
1886 for (camp::idx_t j = 0; j < num_repeats; ++j)
1888 result.set(value, (i << segbits) + j);
1903 std::string s =
"Matrix(" + std::to_string(s_num_rows) +
"x" +
1904 std::to_string(s_num_columns);
1914 for (camp::idx_t r = 0; r < s_num_rows; ++r)
1925 for (camp::idx_t c = 0; c < s_num_columns; ++c)
1931 s += std::to_string(this->
get(r, c));
RAJA header file defining a bit masking operator.
RAJA header file defining SIMD/SIMT register operations.
RAJA header file defining SIMD/SIMT register operations.
Definition: RegisterBase.hpp:39
Definition: MatrixRegisterImpl.hpp:52
RAJA_HOST_DEVICE RAJA_INLINE RAJA::internal::expt::MatrixMatrixMultiplyHelper< self_type, RMAT >::result_type matrix_multiply_add(RMAT const &B, typename RAJA::internal::expt::MatrixMatrixMultiplyHelper< self_type, RMAT >::result_type const &C) const
Definition: MatrixRegisterImpl.hpp:1823
RAJA_HOST_DEVICE constexpr RAJA_INLINE TensorRegister()
Definition: MatrixRegisterImpl.hpp:163
RAJA_HOST_DEVICE RAJA_INLINE self_type & operator=(element_type value)
Set entire matrix to a single scalar value.
Definition: MatrixRegisterImpl.hpp:212
RAJA_INLINE std::string to_string(bool one_line=false) const
Converts the matrix to a string.
Definition: MatrixRegisterImpl.hpp:1901
RAJA_HOST_DEVICE RAJA_INLINE self_type const & store_ref(REF_TYPE &ref) const
Definition: MatrixRegisterImpl.hpp:254
REGISTER_POLICY register_policy
Definition: MatrixRegisterImpl.hpp:66
RAJA_HOST_DEVICE RAJA_INLINE register_type extract_diagonal_register(camp::idx_t starting_column, camp::idx_t segbits, camp::idx_t segment) const
Definition: MatrixRegisterImpl.hpp:1868
RAJA_HOST_DEVICE RAJA_INLINE self_type const & store_strided_nm(element_type *ptr, int row_stride, int col_stride, int num_rows, int num_cols) const
Definition: MatrixRegisterImpl.hpp:1156
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type const & store_packed(element_type *ptr, int row_stride, int col_stride) const
Definition: MatrixRegisterImpl.hpp:912
self_type operator*(SquareMatrixRegister< T2, L, RP > const &y) const
Definition: MatrixRegisterImpl.hpp:227
RAJA_HOST_DEVICE RAJA_INLINE self_type & set(element_type val, int row, int col)
Definition: MatrixRegisterImpl.hpp:1851
RAJA_HOST_DEVICE RAJA_INLINE RAJA::internal::expt::MatrixMatrixMultiplyHelper< self_type, RMAT >::result_type matrix_multiply(RMAT const &mat) const
Definition: MatrixRegisterImpl.hpp:1808
VectorRegister< T2, RP > operator*(VectorRegister< T2, RP > const &y) const
Definition: MatrixRegisterImpl.hpp:237
RAJA_HOST_DEVICE RAJA_INLINE self_type & load_packed_nm(element_type const *ptr, int row_stride, int col_stride, int num_rows, int num_cols)
Definition: MatrixRegisterImpl.hpp:646
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE column_vector_type right_multiply_vector_accumulate(row_vector_type const &v, column_vector_type result) const
Definition: MatrixRegisterImpl.hpp:1558
RAJA_HOST_DEVICE RAJA_INLINE self_type & load_strided(element_type const *ptr, int row_stride, int col_stride)
Definition: MatrixRegisterImpl.hpp:573
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type divide_nm(self_type const &mat, int num_rows, int num_cols) const
Definition: MatrixRegisterImpl.hpp:1272
RAJA_HOST_DEVICE static constexpr RAJA_INLINE bool is_ref_packed()
Definition: MatrixRegisterImpl.hpp:188
RAJA_HOST_DEVICE RAJA_INLINE column_vector_type right_multiply_vector(row_vector_type v) const
Definition: MatrixRegisterImpl.hpp:1531
RAJA_HOST_DEVICE RAJA_INLINE self_type & load_ref(REF_TYPE const &ref)
Definition: MatrixRegisterImpl.hpp:247
RAJA_HOST_DEVICE RAJA_INLINE self_type & load_packed(element_type const *ptr, int row_stride, int col_stride)
Definition: MatrixRegisterImpl.hpp:491
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type const & store_strided(element_type *ptr, int row_stride, int col_stride) const
Definition: MatrixRegisterImpl.hpp:984
RAJA_HOST_DEVICE RAJA_INLINE row_vector_type left_multiply_vector(column_vector_type v) const
Definition: MatrixRegisterImpl.hpp:1543
RAJA_HOST_DEVICE RAJA_INLINE self_type & operator=(self_type const &c)
Definition: MatrixRegisterImpl.hpp:221
RAJA_HOST_DEVICE RAJA_INLINE self_type const & store_packed_nm(element_type *ptr, int row_stride, int col_stride, int num_rows, int num_cols) const
Definition: MatrixRegisterImpl.hpp:1054
RAJA_HOST_DEVICE RAJA_INLINE ~TensorRegister()
Definition: MatrixRegisterImpl.hpp:178
RAJA_HOST_DEVICE static constexpr RAJA_INLINE camp::idx_t s_dim_elem(camp::idx_t dim)
Definition: MatrixRegisterImpl.hpp:200
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE row_vector_type left_multiply_vector_accumulate(column_vector_type const &v, row_vector_type result) const
Definition: MatrixRegisterImpl.hpp:1687
T element_type
Definition: MatrixRegisterImpl.hpp:67
RAJA_HOST_DEVICE RAJA_INLINE element_type get(int row, int col) const
Definition: MatrixRegisterImpl.hpp:1860
RAJA_HOST_DEVICE RAJA_INLINE void matrix_multiply_accumulate(ACCMAT &acc, RMAT const &B) const
Definition: MatrixRegisterImpl.hpp:1840
RAJA_HOST_DEVICE RAJA_INLINE self_type & load_strided_nm(element_type const *ptr, int row_stride, int col_stride, int num_rows, int num_cols)
Definition: MatrixRegisterImpl.hpp:782
RAJA_HOST_DEVICE RAJA_INLINE TensorRegister(element_type c)
Definition: MatrixRegisterImpl.hpp:168
RAJA_INLINE RAJA_HOST_DEVICE TensorRegister(self_type const &c)
Definition: MatrixRegisterImpl.hpp:173
Definition: TensorRegister.hpp:46
Definition: TensorRegisterBase.hpp:105
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
#define RAJA_SUPPRESS_HD_WARN
Definition: macros.hpp:68
TensorTileSize
Definition: TensorRef.hpp:234
@ TENSOR_FULL
Definition: TensorRef.hpp:236
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
Definition: BitMask.hpp:30
Definition: TensorLayout.hpp:35
RAJA::expt::TensorRegister< REGISTER_POLICY, T, TensorLayout< ROW_ORD, COL_ORD >, camp::idx_seq< ROW_SIZE, COL_SIZE > >::RefBridge< RAJA::internal::expt::StaticTensorRef< POINTER_TYPE, INDEX_TYPE, TENSOR_SIZE, camp::int_seq< INDEX_TYPE, StrideInt1, StrideInt2 >, camp::int_seq< INDEX_TYPE, BeginInt1, BeginInt2 >, camp::int_seq< INDEX_TYPE, SizeInt1, SizeInt2 >, STRIDE_ONE_DIM > >::load_ref RAJA_INLINE static RAJA_HOST_DEVICE void load_ref(self_type &self, RefType const &ref)
Performs load specified by TensorRef object.
Definition: MatrixRegisterImpl.hpp:396
RAJA::expt::TensorRegister< REGISTER_POLICY, T, TensorLayout< ROW_ORD, COL_ORD >, camp::idx_seq< ROW_SIZE, COL_SIZE > >::RefBridge< RAJA::internal::expt::StaticTensorRef< POINTER_TYPE, INDEX_TYPE, TENSOR_SIZE, camp::int_seq< INDEX_TYPE, StrideInt1, StrideInt2 >, camp::int_seq< INDEX_TYPE, BeginInt1, BeginInt2 >, camp::int_seq< INDEX_TYPE, SizeInt1, SizeInt2 >, STRIDE_ONE_DIM > >::store_ref RAJA_INLINE static RAJA_HOST_DEVICE void store_ref(self_type const &self, RefType &ref)
Performs store specified by TensorRef object.
Definition: MatrixRegisterImpl.hpp:440
RAJA_INLINE static RAJA_HOST_DEVICE void store_ref(self_type const &self, RefType &ref)
Performs store specified by TensorRef object.
Definition: MatrixRegisterImpl.hpp:322
RAJA_INLINE static RAJA_HOST_DEVICE void load_ref(self_type &self, RefType const &ref)
Performs load specified by TensorRef object.
Definition: MatrixRegisterImpl.hpp:278
Definition: MatrixMatrixMultiply.hpp:36
Definition: TensorRef.hpp:472
Definition: TensorRef.hpp:426
index_type m_stride[NUM_DIMS]
Definition: TensorRef.hpp:442
pointer_type m_pointer
Definition: TensorRef.hpp:441
tile_type m_tile
Definition: TensorRef.hpp:443
index_type m_begin[NUM_DIMS]
Definition: TensorRef.hpp:246
index_type m_size[NUM_DIMS]
Definition: TensorRef.hpp:247