RAJA
RAJA provides a collection of platform portability abstractions for C++ HPC applications.
MatrixRegisterImpl.hpp
Go to the documentation of this file.
1 
11 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
12 // Copyright (c) Lawrence Livermore National Security, LLC and other
13 // RAJA Project Developers. See top-level LICENSE and COPYRIGHT
14 // files for dates and other details. No copyright assignment is required
15 // to contribute to RAJA.
16 //
17 // SPDX-License-Identifier: (BSD-3-Clause)
18 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~//
19 
20 #ifndef RAJA_pattern_tensor_MatrixRegisterImpl_HPP
21 #define RAJA_pattern_tensor_MatrixRegisterImpl_HPP
22 
23 #include "camp/camp.hpp"
24 #include "RAJA/config.hpp"
27 #include "RAJA/util/BitMask.hpp"
28 
29 namespace RAJA
30 {
31 namespace expt
32 {
33 
34 /*
35  * 2D (Matrix) specialization of TensorRegister
36  */
37 template<typename REGISTER_POLICY,
38  typename T,
39  camp::idx_t ROW_ORD,
40  camp::idx_t COL_ORD,
41  camp::idx_t ROW_SIZE,
42  camp::idx_t COL_SIZE>
43 class TensorRegister<REGISTER_POLICY,
44  T,
45  TensorLayout<ROW_ORD, COL_ORD>,
46  camp::idx_seq<ROW_SIZE, COL_SIZE>>
48  TensorRegister<REGISTER_POLICY,
49  T,
50  TensorLayout<ROW_ORD, COL_ORD>,
51  camp::idx_seq<ROW_SIZE, COL_SIZE>>>
52 {
// Public type aliases and compile-time shape constants of the 2-D register.
// NOTE(review): several listing lines in this range (56, 58, 61, 63-65, 68,
// 70, 73, 78, 82, 90, 102, 136) are absent, so some declarations below
// appear truncated.
53 public:
54  using self_type = TensorRegister<REGISTER_POLICY,
55  T,
57  camp::idx_seq<ROW_SIZE, COL_SIZE>>;
59  TensorRegister<REGISTER_POLICY,
60  T,
62  camp::idx_seq<ROW_SIZE, COL_SIZE>>>;
66  using register_policy = REGISTER_POLICY;
67  using element_type = T;
69 
71  TensorRegister<REGISTER_POLICY,
72  T,
74  camp::idx_seq<ROW_SIZE, COL_SIZE>>;
75 
// transpose_type swaps the extents (COL_SIZE x ROW_SIZE); product_type is
// ROW_SIZE x ROW_SIZE.
76  using transpose_type = TensorRegister<REGISTER_POLICY,
77  T,
79  camp::idx_seq<COL_SIZE, ROW_SIZE>>;
80  using product_type = TensorRegister<REGISTER_POLICY,
81  T,
83  camp::idx_seq<ROW_SIZE, ROW_SIZE>>;
84 
// Matrix extents.
85  static constexpr camp::idx_t s_num_rows = ROW_SIZE;
86  static constexpr camp::idx_t s_num_columns = COL_SIZE;
87 
88 
// Elements held by one register (its initializer is on an absent listing
// line).
89  static constexpr camp::idx_t s_elements_per_register =
91 
92  // number of registers to hold entire matrix
93  static constexpr camp::idx_t s_num_registers =
94  (ROW_SIZE * COL_SIZE) / s_elements_per_register;
95 
96  // We only allow matrix sizes that exactly fit in some number of registers
97  static_assert((ROW_SIZE * COL_SIZE) ==
98  s_num_registers * s_elements_per_register,
99  "MatrixRegister must be dimensioned to exactly fit an integer "
100  "number of registers");
101 
103 
// log_base2_t is declared on an absent listing line; presumably it is the
// log2 of s_elements_per_register (confirm). The shift and mask below split a
// storage-order element offset into a register index and a lane, as done in
// to_register / to_lane.
104  static constexpr camp::idx_t s_shift_per_register = log_base2_t::value;
105 
106  static constexpr camp::idx_t s_mask_per_register =
107  (1 << log_base2_t::value) - 1;
108 
109 
// Minor dimension: the one that varies fastest in storage order (columns
// when row-major, rows when column-major). The major dimension is the other.
110  static constexpr camp::idx_t s_minor_dim_elements =
111  layout_type::is_row_major() ? s_num_columns : s_num_rows;
112 
113  static constexpr camp::idx_t s_major_dim_elements =
114  layout_type::is_row_major() ? s_num_rows : s_num_columns;
115 
116  // number of (full) registers that span the minor dim
117  // if a single register is split across multiple rows or columns, then
118  // this is 0
119  static constexpr camp::idx_t s_minor_dim_registers =
120  s_minor_dim_elements / s_elements_per_register;
121 
// Either a register holds a power-of-two number of whole slices, or each
// slice is an exact multiple of the register width.
122  static_assert(s_minor_dim_registers > 0 || log_base2_t::is_exact,
123  "Minor dimension smaller than a vector need to be a power of "
124  "two fraction");
125 
126  static_assert(s_minor_dim_registers == 0 ||
127  (s_minor_dim_elements % s_elements_per_register == 0),
128  "Minor dimensions greater than a vector length must be an "
129  "integer number of vectors");
130 
131 
// Number of whole major-dimension slices (rows when row-major, columns when
// column-major) held by one register. This is integer division, so it is 0
// whenever a slice is longer than a register.
132  static constexpr camp::idx_t s_major_dim_per_register =
133  s_elements_per_register / s_minor_dim_elements;
134 
// Passed as the second argument of the segmented_* register operations
// (its initializer is on an absent listing line).
135  static constexpr camp::idx_t s_segbits =
137 
138 private:
139  template<typename IDX>
140  RAJA_INLINE RAJA_HOST_DEVICE constexpr static auto to_register(IDX row,
141  IDX col) -> IDX
142  {
143  return layout_type::is_row_major()
144  ? (row * IDX(COL_SIZE) + col) >> IDX(s_shift_per_register)
145  : (col * IDX(ROW_SIZE) + row) >> IDX(s_shift_per_register);
146  }
147 
148  template<typename IDX>
149  RAJA_INLINE RAJA_HOST_DEVICE constexpr static auto to_lane(IDX row, IDX col)
150  -> IDX
151  {
152  return layout_type::is_row_major()
153  ? (row * IDX(COL_SIZE) + col) & IDX(s_mask_per_register)
154  : (col * IDX(ROW_SIZE) + row) & IDX(s_mask_per_register);
155  }
156 
157  using base_type::m_registers;
158 
159 public:
161 
162  RAJA_INLINE
163  constexpr TensorRegister() : base_type() {}
164 
166 
167  RAJA_INLINE
168  TensorRegister(element_type c) : base_type(c) { this->broadcast(c); }
169 
// Copy constructor: constructs base_type from `c`, then calls this->copy(c).
// NOTE(review): listing line 172, between RAJA_INLINE and the constructor, is
// absent from this listing.
170  RAJA_INLINE
171 
173  TensorRegister(self_type const& c) : base_type(c) { this->copy(c); }
174 
176 
// NOTE(review): the declaration that follows this RAJA_INLINE (listing line
// 178) is absent from this listing, so its purpose cannot be documented here.
177  RAJA_INLINE
179 
187  template<camp::idx_t STRIDE_ONE_DIM>
188  RAJA_HOST_DEVICE RAJA_INLINE static constexpr bool is_ref_packed()
189  {
190  return (STRIDE_ONE_DIM == 0 && layout_type::is_column_major()) ||
191  (STRIDE_ONE_DIM == 1 && layout_type::is_row_major());
192  }
193 
198 
199  RAJA_INLINE
200  static constexpr camp::idx_t s_dim_elem(camp::idx_t dim)
201  {
202  return dim == 0 ? ROW_SIZE : COL_SIZE;
203  }
204 
210 
// Scalar assignment. Its signature is on listing line 212, which is absent
// here; presumably operator=(element_type value) -- confirm. Broadcasts
// `value` into this register via this->broadcast and returns *this.
211  RAJA_INLINE
213  {
214  this->broadcast(value);
215  return *this;
216  }
217 
219 
220  RAJA_INLINE
221  self_type& operator=(self_type const& c) { return this->copy(c); }
222 
// Templated on T2, L and RP. The signature (listing line 227) is absent
// here; presumably a multiplication with another matrix register `y`
// (confirm). Forwards to matrix_multiply(y).
226  template<typename T2, typename L, typename RP>
228  {
229  return matrix_multiply(y);
230  }
231 
// Templated on T2 and RP. The signature (listing line 237) is absent here;
// presumably a multiplication with a vector register `y` (confirm). Forwards
// to right_multiply_vector(y).
236  template<typename T2, typename RP>
238  {
239  return right_multiply_vector(y);
240  }
241 
242 
// Adapter from a tensor reference type to this register's load/store
// routines. It is declared only; the partial specializations below supply
// load_ref/store_ref for 2-D references.
template<class REF_TYPE>
struct RefBridge;
245 
246  template<typename REF_TYPE>
247  RAJA_HOST_DEVICE RAJA_INLINE self_type& load_ref(REF_TYPE const& ref)
248  {
249  RefBridge<REF_TYPE>::load_ref(*this, ref);
250  return *this;
251  }
252 
253  template<typename REF_TYPE>
254  RAJA_HOST_DEVICE RAJA_INLINE self_type const& store_ref(REF_TYPE& ref) const
255  {
256  RefBridge<REF_TYPE>::store_ref(*this, ref);
257  return *this;
258  }
259 
// RefBridge specialization for a 2-D RAJA::internal::expt::TensorRef.
// load_ref and store_ref first compute the tile's base address
// (m_pointer + m_begin[0] * m_stride[0] + m_begin[1] * m_stride[1]). They
// then call one of four routines on `self`:
//   - packed (is_ref_packed<STRIDE_ONE_DIM>()) vs. strided access;
//   - full tile (TENSOR_SIZE == TENSOR_FULL) vs. partial tile. The partial
//     ("_nm") routines also receive m_tile.m_size[0] and m_tile.m_size[1].
// NOTE(review): the TENSOR_SIZE template-parameter line and the lines just
// before each load_ref/store_ref definition are absent from this listing.
260  template<typename POINTER_TYPE,
261  typename INDEX_TYPE,
263  camp::idx_t STRIDE_ONE_DIM>
264  struct RefBridge<
265  RAJA::internal::expt::
266  TensorRef<POINTER_TYPE, INDEX_TYPE, TENSOR_SIZE, 2, STRIDE_ONE_DIM>>
267  {
268 
269  using RefType = RAJA::internal::expt::
270  TensorRef<POINTER_TYPE, INDEX_TYPE, TENSOR_SIZE, 2, STRIDE_ONE_DIM>;
271 
// Loads `self` from the tile described by `ref`.
275  RAJA_INLINE
276 
278  static void load_ref(self_type& self, RefType const& ref)
279  {
280 
281  auto ptr = ref.m_pointer + ref.m_tile.m_begin[0] * ref.m_stride[0] +
282  ref.m_tile.m_begin[1] * ref.m_stride[1];
283 
284  // check for packed data
285  if (self.is_ref_packed<STRIDE_ONE_DIM>())
286  {
287  // full tile (TENSOR_SIZE == TENSOR_FULL)?
288  if (TENSOR_SIZE == RAJA::internal::expt::TENSOR_FULL)
289  {
290  self.load_packed(ptr, ref.m_stride[0], ref.m_stride[1]);
291  }
292  // partial tile, bounded by m_tile.m_size[0] x m_tile.m_size[1]
293  else
294  {
295  self.load_packed_nm(ptr, ref.m_stride[0], ref.m_stride[1],
296  ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
297  }
298  }
299  // strided data
300  else
301  {
302  // full tile (TENSOR_SIZE == TENSOR_FULL)?
303  if (TENSOR_SIZE == RAJA::internal::expt::TENSOR_FULL)
304  {
305  self.load_strided(ptr, ref.m_stride[0], ref.m_stride[1]);
306  }
307  // partial tile, bounded by m_tile.m_size[0] x m_tile.m_size[1]
308  else
309  {
310  self.load_strided_nm(ptr, ref.m_stride[0], ref.m_stride[1],
311  ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
312  }
313  }
314  }
315 
// Stores `self` into the tile described by `ref`.
319  RAJA_INLINE
320 
322  static void store_ref(self_type const& self, RefType& ref)
323  {
324 
325  auto ptr = ref.m_pointer + ref.m_tile.m_begin[0] * ref.m_stride[0] +
326  ref.m_tile.m_begin[1] * ref.m_stride[1];
327 
328  // check for packed data
329  if (self.is_ref_packed<STRIDE_ONE_DIM>())
330  {
331  // full tile (TENSOR_SIZE == TENSOR_FULL)?
332  if (TENSOR_SIZE == RAJA::internal::expt::TENSOR_FULL)
333  {
334  self.store_packed(ptr, ref.m_stride[0], ref.m_stride[1]);
335  }
336  // partial tile, bounded by m_tile.m_size[0] x m_tile.m_size[1]
337  else
338  {
339  self.store_packed_nm(ptr, ref.m_stride[0], ref.m_stride[1],
340  ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
341  }
342  }
343  // strided data
344  else
345  {
346  // full tile (TENSOR_SIZE == TENSOR_FULL)?
347  if (TENSOR_SIZE == RAJA::internal::expt::TENSOR_FULL)
348  {
349  self.store_strided(ptr, ref.m_stride[0], ref.m_stride[1]);
350  }
351  // partial tile, bounded by m_tile.m_size[0] x m_tile.m_size[1]
352  else
353  {
354  self.store_strided_nm(ptr, ref.m_stride[0], ref.m_stride[1],
355  ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
356  }
357  }
358  }
359  };
360 
// RefBridge specialization for the 2-D reference type whose strides, begin
// indices and sizes are camp::int_seq template arguments. The type's name is
// on listing line 371, which is absent here.
// The load_ref/store_ref bodies are the same as the TensorRef bridge: they
// still read ref.m_pointer, ref.m_stride and ref.m_tile at run time. They
// choose packed vs. strided (is_ref_packed<STRIDE_ONE_DIM>()) and full vs.
// partial tile (TENSOR_SIZE == TENSOR_FULL).
// NOTE(review): several declaration lines (363, 371, 381, 390-392, 395,
// 434-436, 439) are absent from this listing.
361  template<typename POINTER_TYPE,
362  typename INDEX_TYPE,
364  INDEX_TYPE StrideInt1,
365  INDEX_TYPE StrideInt2,
366  INDEX_TYPE BeginInt1,
367  INDEX_TYPE BeginInt2,
368  INDEX_TYPE SizeInt1,
369  INDEX_TYPE SizeInt2,
370  camp::idx_t STRIDE_ONE_DIM>
372  POINTER_TYPE,
373  INDEX_TYPE,
374  TENSOR_SIZE,
375  camp::int_seq<INDEX_TYPE, StrideInt1, StrideInt2>,
376  camp::int_seq<INDEX_TYPE, BeginInt1, BeginInt2>,
377  camp::int_seq<INDEX_TYPE, SizeInt1, SizeInt2>,
378  STRIDE_ONE_DIM>>
379  {
380 
382  POINTER_TYPE,
383  INDEX_TYPE,
384  TENSOR_SIZE,
385  camp::int_seq<INDEX_TYPE, StrideInt1, StrideInt2>,
386  camp::int_seq<INDEX_TYPE, BeginInt1, BeginInt2>,
387  camp::int_seq<INDEX_TYPE, SizeInt1, SizeInt2>,
388  STRIDE_ONE_DIM>;
389 
// Loads `self` from the tile described by `ref`.
393  RAJA_INLINE
394 
396  static void load_ref(self_type& self, RefType const& ref)
397  {
398 
399  auto ptr = ref.m_pointer + ref.m_tile.m_begin[0] * ref.m_stride[0] +
400  ref.m_tile.m_begin[1] * ref.m_stride[1];
401 
402  // check for packed data
403  if (self.is_ref_packed<STRIDE_ONE_DIM>())
404  {
405  // full tile (TENSOR_SIZE == TENSOR_FULL)?
406  if (TENSOR_SIZE == RAJA::internal::expt::TENSOR_FULL)
407  {
408  self.load_packed(ptr, ref.m_stride[0], ref.m_stride[1]);
409  }
410  // partial tile, bounded by m_tile.m_size[0] x m_tile.m_size[1]
411  else
412  {
413  self.load_packed_nm(ptr, ref.m_stride[0], ref.m_stride[1],
414  ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
415  }
416  }
417  // strided data
418  else
419  {
420  // full tile (TENSOR_SIZE == TENSOR_FULL)?
421  if (TENSOR_SIZE == RAJA::internal::expt::TENSOR_FULL)
422  {
423  self.load_strided(ptr, ref.m_stride[0], ref.m_stride[1]);
424  }
425  // partial tile, bounded by m_tile.m_size[0] x m_tile.m_size[1]
426  else
427  {
428  self.load_strided_nm(ptr, ref.m_stride[0], ref.m_stride[1],
429  ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
430  }
431  }
432  }
433 
// Stores `self` into the tile described by `ref`.
437  RAJA_INLINE
438 
440  static void store_ref(self_type const& self, RefType& ref)
441  {
442 
443  auto ptr = ref.m_pointer + ref.m_tile.m_begin[0] * ref.m_stride[0] +
444  ref.m_tile.m_begin[1] * ref.m_stride[1];
445 
446  // check for packed data
447  if (self.is_ref_packed<STRIDE_ONE_DIM>())
448  {
449  // full tile (TENSOR_SIZE == TENSOR_FULL)?
450  if (TENSOR_SIZE == RAJA::internal::expt::TENSOR_FULL)
451  {
452  self.store_packed(ptr, ref.m_stride[0], ref.m_stride[1]);
453  }
454  // partial tile, bounded by m_tile.m_size[0] x m_tile.m_size[1]
455  else
456  {
457  self.store_packed_nm(ptr, ref.m_stride[0], ref.m_stride[1],
458  ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
459  }
460  }
461  // strided data
462  else
463  {
464  // full tile (TENSOR_SIZE == TENSOR_FULL)?
465  if (TENSOR_SIZE == RAJA::internal::expt::TENSOR_FULL)
466  {
467  self.store_strided(ptr, ref.m_stride[0], ref.m_stride[1]);
468  }
469  // partial tile, bounded by m_tile.m_size[0] x m_tile.m_size[1]
470  else
471  {
472  self.store_strided_nm(ptr, ref.m_stride[0], ref.m_stride[1],
473  ref.m_tile.m_size[0], ref.m_tile.m_size[1]);
474  }
475  }
476  }
477  };
478 
489 
// Loads the whole ROW_SIZE x COL_SIZE matrix from `ptr`. Returns *this.
// The first parameter line (listing line 491) is absent from this listing.
//  - If the major-dimension stride equals the minor extent (row_stride ==
//    COL_SIZE when row-major, col_stride == ROW_SIZE when column-major), the
//    data is one contiguous block, and each register is load_packed from
//    ptr + reg * s_elements_per_register.
//  - Otherwise, if each row (row-major) or column (column-major) spans whole
//    registers, register k of slice s is load_packed from
//    s * major_stride + k * s_elements_per_register. The minor-dimension
//    stride is not applied here; presumably it is 1 for packed data (confirm).
//  - Otherwise the call is forwarded to load_strided.
490  RAJA_INLINE
492  int row_stride,
493  int col_stride)
494  {
495  // whole matrix is contiguous: load consecutive register-sized chunks
496  if ((layout_type::is_row_major() && (row_stride == COL_SIZE)) ||
497  (layout_type::is_column_major() && (col_stride == ROW_SIZE)))
498  {
499 
500  for (camp::idx_t reg = 0; reg < s_num_registers; ++reg)
501  {
502  m_registers[reg].load_packed(ptr + reg * s_elements_per_register);
503  }
504  }
505  // Do semi-dense load for row-major
506  else if (layout_type::is_row_major())
507  {
508 
509  // one or more registers per row
510  if (s_minor_dim_registers)
511  {
512  camp::idx_t reg = 0;
513  for (camp::idx_t row = 0; row < ROW_SIZE; ++row)
514  {
515  for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
516  {
517 
518  camp::idx_t offset =
519  row * row_stride + colreg * s_elements_per_register;
520 
521  m_registers[reg].load_packed(ptr + offset);
522 
523  reg++;
524  }
525  }
526  }
527  // more than one row per register
528  else
529  {
530  // default to strided operation
531  return load_strided(ptr, row_stride, col_stride);
532  }
533  }
534  // Do semi-dense load for column-major
535  else
536  {
537  // one or more registers per column
538  if (s_minor_dim_registers)
539  {
540 
541  camp::idx_t reg = 0;
542  for (camp::idx_t col = 0; col < COL_SIZE; ++col)
543  {
544  for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
545  {
546 
547  camp::idx_t offset =
548  col * col_stride + rowreg * s_elements_per_register;
549 
550  m_registers[reg].load_packed(ptr + offset);
551 
552  reg++;
553  }
554  }
555  }
556  // more than one column per register
557  else
558  {
559  // default to strided operation
560  return load_strided(ptr, row_stride, col_stride);
561  }
562  }
563 
564  return *this;
565  }
566 
571 
572  RAJA_INLINE
574  int row_stride,
575  int col_stride)
576  {
577 
578  if (layout_type::is_row_major())
579  {
580  // one or more registers per row
581  if (s_minor_dim_registers)
582  {
583  for (camp::idx_t i = 0; i < s_num_registers; ++i)
584  {
585  camp::idx_t row =
586  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
587  camp::idx_t col =
588  s_elements_per_register * (i - (row * s_minor_dim_registers));
589  m_registers[i].load_strided(ptr + row * row_stride + col * col_stride,
590  col_stride);
591  }
592  }
593  // less than one register per row
594  else
595  {
596  for (camp::idx_t i = 0; i < s_num_registers; ++i)
597  {
598  element_type const* ptr_i =
599  ptr + i * row_stride * s_major_dim_per_register;
600  m_registers[i].segmented_load(ptr_i, s_segbits, col_stride,
601  row_stride);
602  }
603  }
604  }
605 
606  // column major
607  else
608  {
609 
610  // one or more registers per column
611  if (s_minor_dim_registers)
612  {
613  for (camp::idx_t i = 0; i < s_num_registers; ++i)
614  {
615  camp::idx_t col =
616  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
617  camp::idx_t row =
618  s_elements_per_register * (i - (col * s_minor_dim_registers));
619 
620  m_registers[i].load_strided(ptr + row * row_stride + col * col_stride,
621  row_stride);
622  }
623  }
624  // less than one register per column
625  else
626  {
627  for (camp::idx_t i = 0; i < s_num_registers; ++i)
628  {
629  element_type const* ptr_i =
630  ptr + i * col_stride * s_major_dim_per_register;
631  m_registers[i].segmented_load(ptr_i, s_segbits, row_stride,
632  col_stride);
633  }
634  }
635  }
636 
637  return *this;
638  }
639 
644 
645  RAJA_INLINE
647  int row_stride,
648  int col_stride,
649  int num_rows,
650  int num_cols)
651  {
652 
653  if (layout_type::is_row_major())
654  {
655 
656  // one or more registers per column
657  if (s_minor_dim_registers)
658  {
659 
660  for (camp::idx_t row = 0; row < num_rows; ++row)
661  {
662  for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
663  {
664 
665  camp::idx_t reg = row * s_minor_dim_registers + colreg;
666 
667  camp::idx_t col0 = colreg * s_elements_per_register;
668  camp::idx_t offset = row * row_stride + col0;
669 
670  // loading a complete register
671  if (col0 + s_elements_per_register <= num_cols)
672  {
673  m_registers[reg].load_packed(ptr + offset);
674  }
675 
676  // partial register at end of row
677  else
678  {
679  m_registers[reg].load_packed_n(ptr + offset, num_cols - col0);
680 
681  // zero out the remaining registers, if any
682  for (camp::idx_t i = colreg + 1; i < s_minor_dim_registers; ++i)
683  {
684  reg++;
685  m_registers[reg] = element_type(0);
686  }
687 
688  break; // end this row
689  }
690  }
691  }
692 
693  // zero out remaining rows
694  for (camp::idx_t row = num_rows; row < ROW_SIZE; ++row)
695  {
696  for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
697  {
698 
699  camp::idx_t reg = row * s_minor_dim_registers + colreg;
700 
701  m_registers[reg] = element_type(0);
702  }
703  }
704  }
705  // more than one column per register
706  else
707  {
708  // default to strided operation
709  return load_strided_nm(ptr, row_stride, col_stride, num_rows, num_cols);
710  }
711  }
712  // Do semi-dense load for column-major
713  else
714  {
715 
716  // one or more registers per column
717  if (s_minor_dim_registers)
718  {
719 
720  for (camp::idx_t col = 0; col < num_cols; ++col)
721  {
722  for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
723  {
724 
725  camp::idx_t reg = col * s_minor_dim_registers + rowreg;
726 
727  camp::idx_t row0 = rowreg * s_elements_per_register;
728  camp::idx_t offset = col * col_stride + row0;
729 
730  // loading a complete register
731  if (row0 + s_elements_per_register <= num_rows)
732  {
733  m_registers[reg].load_packed(ptr + offset);
734  }
735 
736  // partial register at end of column
737  else
738  {
739  m_registers[reg].load_packed_n(ptr + offset, num_rows - row0);
740 
741  // zero out the remaining registers, if any
742  for (camp::idx_t i = rowreg + 1; i < s_minor_dim_registers; ++i)
743  {
744  reg++;
745  m_registers[reg] = element_type(0);
746  }
747 
748  break; // end this column
749  }
750  }
751  }
752  // zero out remaining columns
753  for (camp::idx_t col = num_cols; col < COL_SIZE; ++col)
754  {
755  for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
756  {
757 
758  camp::idx_t reg = col * s_minor_dim_registers + rowreg;
759 
760  m_registers[reg] = element_type(0);
761  }
762  }
763  }
764  // more than one column per register
765  else
766  {
767 
768  // default to strided operation
769  return load_strided_nm(ptr, row_stride, col_stride, num_rows, num_cols);
770  }
771  }
772 
773  return *this;
774  }
775 
780 
781  RAJA_INLINE
783  int row_stride,
784  int col_stride,
785  int num_rows,
786  int num_cols)
787  {
788 
789  if (layout_type::is_row_major())
790  {
791  // one or more registers per row
792  if (s_minor_dim_registers)
793  {
794 
795  for (camp::idx_t i = 0; i < s_num_registers; ++i)
796  {
797  camp::idx_t row =
798  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
799  if (row >= num_rows)
800  {
801  m_registers[i] = element_type(0);
802  }
803  else
804  {
805  camp::idx_t col =
806  s_elements_per_register * (i - (row * s_minor_dim_registers));
807 
808 
809  camp::idx_t reg_num_cols = s_elements_per_register;
810  if (reg_num_cols + col > num_cols)
811  {
812  reg_num_cols = num_cols - col;
813  m_registers[i].load_strided_n(ptr + row * row_stride +
814  col * col_stride,
815  col_stride, reg_num_cols);
816  }
817  else
818  {
819  m_registers[i].load_strided(
820  ptr + row * row_stride + col * col_stride, col_stride);
821  }
822  }
823  }
824  }
825  // less than one register per row
826  else
827  {
828 
829  for (camp::idx_t i = 0; i < s_num_registers; ++i)
830  {
831  // figure out how many rows get loaded in this register
832  camp::idx_t reg_num_rows = num_rows - i * s_major_dim_per_register;
833  reg_num_rows = reg_num_rows > s_major_dim_per_register
834  ? s_major_dim_per_register
835  : reg_num_rows;
836 
837  element_type const* ptr_i =
838  ptr + i * row_stride * s_major_dim_per_register;
839  m_registers[i].segmented_load_nm(ptr_i, s_segbits, col_stride,
840  row_stride, num_cols, reg_num_rows);
841  }
842  }
843  }
844 
845  // column major
846  else
847  {
848 
849  // one or more registers per column
850  if (s_minor_dim_registers)
851  {
852  for (camp::idx_t i = 0; i < s_num_registers; ++i)
853  {
854  camp::idx_t col =
855  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
856  if (col >= num_cols)
857  {
858  m_registers[i] = element_type(0);
859  }
860  else
861  {
862  camp::idx_t row =
863  s_elements_per_register * (i - (col * s_minor_dim_registers));
864 
865  camp::idx_t reg_num_rows = s_elements_per_register;
866  if (reg_num_rows + row > num_rows)
867  {
868  reg_num_rows = num_rows - row;
869  m_registers[i].load_strided_n(ptr + row * row_stride +
870  col * col_stride,
871  row_stride, reg_num_rows);
872  }
873  else
874  {
875  m_registers[i].load_strided(
876  ptr + row * row_stride + col * col_stride, row_stride);
877  }
878  }
879  }
880  }
881  // less than one register per column
882  else
883  {
884  for (camp::idx_t i = 0; i < s_num_registers; ++i)
885  {
886  // figure out how many columns get loaded in this register
887  camp::idx_t reg_num_cols = num_cols - i * s_major_dim_per_register;
888  reg_num_cols = reg_num_cols > s_major_dim_per_register
889  ? s_major_dim_per_register
890  : reg_num_cols;
891 
892  element_type const* ptr_i =
893  ptr + i * col_stride * s_major_dim_per_register;
894  m_registers[i].segmented_load_nm(ptr_i, s_segbits, row_stride,
895  col_stride, num_rows, reg_num_cols);
896  }
897  }
898  }
899 
900  return *this;
901  }
902 
910 
911  RAJA_INLINE
913  int row_stride,
914  int col_stride) const
915  {
916 
917  // if it's dense in columns and rows, just do a dense load
918  if ((layout_type::is_row_major() && (row_stride == COL_SIZE)) ||
919  (layout_type::is_column_major() && (col_stride == ROW_SIZE)))
920  {
921 
922  for (camp::idx_t reg = 0; reg < s_num_registers; ++reg)
923  {
924  m_registers[reg].store_packed(ptr + reg * s_elements_per_register);
925  }
926  }
927  // Do semi-dense store for row-major
928  else if (layout_type::is_row_major())
929  {
930 
931  // one or more registers per column
932  if (s_minor_dim_registers)
933  {
934  for (camp::idx_t i = 0; i < s_num_registers; ++i)
935  {
936  camp::idx_t row =
937  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
938  camp::idx_t col =
939  s_elements_per_register * (i - (row * s_minor_dim_registers));
940  m_registers[i].store_packed(ptr + row * row_stride +
941  col * col_stride);
942  }
943  }
944  // more than one column per register
945  else
946  {
947  store_strided(ptr, row_stride, col_stride);
948  }
949  }
950  // Do semi-dense store for column-major
951  else
952  {
953  // one or more registers per row
954  if (s_minor_dim_registers)
955  {
956  for (camp::idx_t i = 0; i < s_num_registers; ++i)
957  {
958  camp::idx_t col =
959  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
960  camp::idx_t row =
961  s_elements_per_register * (i - (col * s_minor_dim_registers));
962  m_registers[i].store_packed(ptr + row * row_stride +
963  col * col_stride);
964  }
965  }
966  // more than one row per register
967  else
968  {
969  store_strided(ptr, row_stride, col_stride);
970  }
971  }
972 
973 
974  return *this;
975  }
976 
982 
983  RAJA_INLINE
985  int row_stride,
986  int col_stride) const
987  {
988 
989 
990  if (layout_type::is_row_major())
991  {
992  // one or more registers per row
993  if (s_minor_dim_registers)
994  {
995  for (camp::idx_t i = 0; i < s_num_registers; ++i)
996  {
997  camp::idx_t row =
998  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
999  camp::idx_t col =
1000  s_elements_per_register * (i - (row * s_minor_dim_registers));
1001  m_registers[i].store_strided(
1002  ptr + row * row_stride + col * col_stride, col_stride);
1003  }
1004  }
1005  // less than one register per row
1006  else
1007  {
1008  for (camp::idx_t i = 0; i < s_num_registers; ++i)
1009  {
1010  element_type* ptr_i = ptr + i * row_stride * s_major_dim_per_register;
1011  m_registers[i].segmented_store(ptr_i, s_segbits, col_stride,
1012  row_stride);
1013  }
1014  }
1015  }
1016 
1017  // column major
1018  else
1019  {
1020  // one or more registers per column
1021  if (s_minor_dim_registers)
1022  {
1023  for (camp::idx_t i = 0; i < s_num_registers; ++i)
1024  {
1025  camp::idx_t col =
1026  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
1027  camp::idx_t row =
1028  s_elements_per_register * (i - (col * s_minor_dim_registers));
1029  m_registers[i].store_strided(
1030  ptr + row * row_stride + col * col_stride, row_stride);
1031  }
1032  }
1033  // less than one register per column
1034  else
1035  {
1036  for (camp::idx_t i = 0; i < s_num_registers; ++i)
1037  {
1038  element_type* ptr_i = ptr + i * col_stride * s_major_dim_per_register;
1039  m_registers[i].segmented_store(ptr_i, s_segbits, row_stride,
1040  col_stride);
1041  }
1042  }
1043  }
1044 
1045  return *this;
1046  }
1047 
1052 
1053  RAJA_INLINE
1055  int row_stride,
1056  int col_stride,
1057  int num_rows,
1058  int num_cols) const
1059  {
1060 
1061 
1062  if (layout_type::is_row_major())
1063  {
1064 
1065  // one or more registers per column
1066  if (s_minor_dim_registers)
1067  {
1068 
1069  for (camp::idx_t row = 0; row < num_rows; ++row)
1070  {
1071  for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
1072  {
1073 
1074  camp::idx_t reg = row * s_minor_dim_registers + colreg;
1075 
1076  camp::idx_t col0 = colreg * s_elements_per_register;
1077  camp::idx_t offset = row * row_stride + col0;
1078 
1079  // store a complete register
1080  if (col0 + s_elements_per_register <= num_cols)
1081  {
1082  m_registers[reg].store_packed(ptr + offset);
1083  }
1084 
1085  // partial register at end of row
1086  else
1087  {
1088  m_registers[reg].store_packed_n(ptr + offset, num_cols - col0);
1089 
1090  break; // end this row
1091  }
1092  }
1093  }
1094  }
1095  // more than one column per register
1096  else
1097  {
1098  // default to strided operation
1099  return store_strided_nm(ptr, row_stride, col_stride, num_rows,
1100  num_cols);
1101  }
1102  }
1103  // Do semi-dense store for column-major
1104  else
1105  {
1106 
1107  // one or more registers per column
1108  if (s_minor_dim_registers)
1109  {
1110 
1111  for (camp::idx_t col = 0; col < num_cols; ++col)
1112  {
1113  for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
1114  {
1115 
1116  camp::idx_t reg = col * s_minor_dim_registers + rowreg;
1117 
1118  camp::idx_t row0 = rowreg * s_elements_per_register;
1119  camp::idx_t offset = col * col_stride + row0;
1120 
1121  // loading a complete register
1122  if (row0 + s_elements_per_register <= num_rows)
1123  {
1124  m_registers[reg].store_packed(ptr + offset);
1125  }
1126 
1127  // partial register at end of column
1128  else
1129  {
1130  m_registers[reg].store_packed_n(ptr + offset, num_rows - row0);
1131 
1132  break; // end this column
1133  }
1134  }
1135  }
1136  }
1137  // more than one column per register
1138  else
1139  {
1140 
1141  // default to strided operation
1142  return store_strided_nm(ptr, row_stride, col_stride, num_rows,
1143  num_cols);
1144  }
1145  }
1146 
1147  return *this;
1148  }
1149 
1154 
1155  RAJA_INLINE
1157  int row_stride,
1158  int col_stride,
1159  int num_rows,
1160  int num_cols) const
1161  {
1162 
1163 
1164  if (layout_type::is_row_major())
1165  {
1166  // one or more registers per row
1167  if (s_minor_dim_registers)
1168  {
1169 
1170  for (camp::idx_t i = 0; i < s_num_registers; ++i)
1171  {
1172  camp::idx_t row =
1173  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
1174  if (row < num_rows)
1175  {
1176  camp::idx_t col =
1177  s_elements_per_register * (i - (row * s_minor_dim_registers));
1178 
1179 
1180  camp::idx_t reg_num_cols = s_elements_per_register;
1181  if (reg_num_cols + col > num_cols)
1182  {
1183  reg_num_cols = num_cols - col;
1184  m_registers[i].store_strided_n(ptr + row * row_stride +
1185  col * col_stride,
1186  col_stride, reg_num_cols);
1187  }
1188  else
1189  {
1190  m_registers[i].store_strided(
1191  ptr + row * row_stride + col * col_stride, col_stride);
1192  }
1193  }
1194  }
1195  }
1196  // less than one register per row
1197  else
1198  {
1199 
1200  for (camp::idx_t i = 0; i < s_num_registers; ++i)
1201  {
1202  // figure out how many rows get loaded in this register
1203  camp::idx_t reg_num_rows = num_rows - i * s_major_dim_per_register;
1204  reg_num_rows = reg_num_rows > s_major_dim_per_register
1205  ? s_major_dim_per_register
1206  : reg_num_rows;
1207 
1208  element_type* ptr_i = ptr + i * row_stride * s_major_dim_per_register;
1209  m_registers[i].segmented_store_nm(ptr_i, s_segbits, col_stride,
1210  row_stride, num_cols, reg_num_rows);
1211  }
1212  }
1213  }
1214 
1215  // column major
1216  else
1217  {
1218 
1219  // one or more registers per column
1220  if (s_minor_dim_registers)
1221  {
1222  for (camp::idx_t i = 0; i < s_num_registers; ++i)
1223  {
1224  camp::idx_t col =
1225  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
1226  if (col < num_cols)
1227  {
1228  camp::idx_t row =
1229  s_elements_per_register * (i - (col * s_minor_dim_registers));
1230 
1231  camp::idx_t reg_num_rows = s_elements_per_register;
1232  if (reg_num_rows + row > num_rows)
1233  {
1234  reg_num_rows = num_rows - row;
1235  m_registers[i].store_strided_n(ptr + row * row_stride +
1236  col * col_stride,
1237  row_stride, reg_num_rows);
1238  }
1239  else
1240  {
1241  m_registers[i].store_strided(
1242  ptr + row * row_stride + col * col_stride, row_stride);
1243  }
1244  }
1245  }
1246  }
1247  // less than one register per column
1248  else
1249  {
1250  for (camp::idx_t i = 0; i < s_num_registers; ++i)
1251  {
1252  // figure out how many columns get loaded in this register
1253  camp::idx_t reg_num_cols = num_cols - i * s_major_dim_per_register;
1254  reg_num_cols = reg_num_cols > s_major_dim_per_register
1255  ? s_major_dim_per_register
1256  : reg_num_cols;
1257 
1258  element_type* ptr_i = ptr + i * col_stride * s_major_dim_per_register;
1259  m_registers[i].segmented_store_nm(ptr_i, s_segbits, row_stride,
1260  col_stride, num_rows, reg_num_cols);
1261  }
1262  }
1263  }
1264 
1265  return *this;
1266  }
1267 
1270 
1271  RAJA_INLINE
1272  self_type divide_nm(self_type const& mat, int num_rows, int num_cols) const
1273  {
1274  self_type result;
1275 
1276 
1277  if (layout_type::is_row_major())
1278  {
1279  // one or more registers per row
1280  if (s_minor_dim_registers)
1281  {
1282 
1283  for (camp::idx_t i = 0; i < s_num_registers; ++i)
1284  {
1285  camp::idx_t row =
1286  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
1287  if (row < num_rows)
1288  {
1289  camp::idx_t col =
1290  s_elements_per_register * (i - (row * s_minor_dim_registers));
1291 
1292 
1293  camp::idx_t reg_num_cols = s_elements_per_register;
1294  if (reg_num_cols + col > num_cols)
1295  {
1296  reg_num_cols = num_cols - col;
1297  result.m_registers[i] =
1298  m_registers[i].divide_n(mat.m_registers[i], reg_num_cols);
1299  }
1300  else
1301  {
1302  result.m_registers[i] = m_registers[i].divide(mat.m_registers[i]);
1303  }
1304  }
1305  }
1306  }
1307  // less than one register per row
1308  else
1309  {
1310 
1311  for (camp::idx_t i = 0; i < s_num_registers; ++i)
1312  {
1313  // figure out how many rows get loaded in this register
1314  camp::idx_t reg_num_rows = num_rows - i * s_major_dim_per_register;
1315  reg_num_rows = reg_num_rows > s_major_dim_per_register
1316  ? s_major_dim_per_register
1317  : reg_num_rows;
1318 
1319  result.m_registers[i] = m_registers[i].segmented_divide_nm(
1320  mat.m_registers[i], s_segbits, num_cols, reg_num_rows);
1321  }
1322  }
1323  }
1324 
1325  // column major
1326  else
1327  {
1328 
1329  // one or more registers per column
1330  if (s_minor_dim_registers)
1331  {
1332  for (camp::idx_t i = 0; i < s_num_registers; ++i)
1333  {
1334  camp::idx_t col =
1335  i / (s_minor_dim_registers ? s_minor_dim_registers : 1);
1336  if (col < num_cols)
1337  {
1338  camp::idx_t row =
1339  s_elements_per_register * (i - (col * s_minor_dim_registers));
1340 
1341  camp::idx_t reg_num_rows = s_elements_per_register;
1342  if (reg_num_rows + row > num_rows)
1343  {
1344  reg_num_rows = num_rows - row;
1345  result.m_registers[i] =
1346  m_registers[i].divide_n(mat.m_registers[i], reg_num_rows);
1347  }
1348  else
1349  {
1350  result.m_registers[i] = m_registers[i].divide(mat.m_registers[i]);
1351  }
1352  }
1353  }
1354  }
1355  // less than one register per column
1356  else
1357  {
1358  for (camp::idx_t i = 0; i < s_num_registers; ++i)
1359  {
1360  // figure out how many columns get loaded in this register
1361  camp::idx_t reg_num_cols = num_cols - i * s_major_dim_per_register;
1362  reg_num_cols = reg_num_cols > s_major_dim_per_register
1363  ? s_major_dim_per_register
1364  : reg_num_cols;
1365 
1366  result.m_registers[i] = m_registers[i].segmented_divide_nm(
1367  mat.m_registers[i], s_segbits, num_rows, reg_num_cols);
1368  }
1369  }
1370  }
1371 
1372 
1373  return result;
1374  }
1375 
// NOTE(review): everything between `#if 0` and `#endif` below is disabled and
// is never compiled. It holds an unfinished transpose() based on Eklundh's
// recursive block transpose, plus inplace_transpose() and transpose_by_type().
// Known problems to resolve before re-enabling it:
//  - the s_minor_dim_registers == 0 branch is entirely commented out, so for
//    that layout transpose() would return an untransposed copy of *this
//    reinterpreted as transpose_type;
//  - `skip_bits` is declared inside the first per-level loop but read in the
//    second (register-level) loop, where it is out of scope;
//  - the `else` branch of the register-level loop is empty, so the tmp
//    registers for those `major` values are never assigned;
//  - the source itself notes it only works for square matrices, and
//    inplace_transpose() assigns a transpose_type to *this (a self_type),
//    which presumably only works when the two types coincide (square case);
//  - the final reinterpret_cast between distinct TensorRegister types relies
//    on identical object layout (std::bit_cast/memcpy would avoid aliasing
//    concerns).
1381 #if 0
1383  RAJA_INLINE
1384  transpose_type transpose() const {
1385 
1386  static constexpr camp::idx_t num_elem = register_type::s_num_elem;
1387 
1388  /*
1389  * We use Eklundh's Algorithm: Recursive block transpose because
1390  * it's easy to implement using SIMD register permutation primitives
1391  *
1392  * Executes in n*log(n) row operations
1393  *
1394  * Also, the algorithm is the same for row and column major.
1395  */
1396  self_type result = *this;
1397  // 1 register is split over multiple rows
1398  if(s_minor_dim_registers == 0){
1399 // for(camp::idx_t lvl = 0; (1<<lvl) < num_elem;++ lvl){
1400 // // At this level, we do block transposes of NxN sub-matrices, where
1401 // // N = 1<<lvl
1402 //
1403 // auto const &vals = result.m_registers;
1404 //
1405 // self_type tmp;
1406 // for(camp::idx_t i = 0;i < s_num_registers;++ i){
1407 // if(((i>>lvl)&0x1) == 0){
1408 // tmp.m_registers[i] = vals[i - (i&(1<<lvl))].transpose_shuffle_left(lvl, vals[i - (i&(1<<lvl)) + (1<<lvl)]);
1409 // }
1410 // else{
1411 // tmp.m_registers[i] = vals[i - (i&(1<<lvl))].transpose_shuffle_right(lvl, vals[i - (i&(1<<lvl)) + (1<<lvl)]);
1412 // }
1413 // }
1414 // result = tmp;
1415 // }
1416  }
1417  // one or more registers per row/column
1418  else{
1419 
1420 
1421  // This only works with square matrices.... need to generalize
1422  for(camp::idx_t lvl = 0; (1<<lvl) < num_elem;++ lvl){
1423  // At this level, we do block transposes of NxN sub-matrices, where
1424  // N = 1<<lvl
1425 
1426  camp::idx_t skip_bits = 0;
1427  if(transpose_type::s_major_dim_per_register <= 1){
1428  skip_bits = lvl;
1429  }
1430  camp::idx_t skip_reg = (1<<skip_bits)*s_minor_dim_registers;
1431 
1432  auto const &vals = result.m_registers;
1433 
1434  self_type tmp;
1435  for(camp::idx_t major = 0;major < s_major_dim_elements;++ major){
1436  if(((major>>skip_bits)&0x1) == 0){
1437  for(camp::idx_t i = major*s_minor_dim_registers;i < (major+1)*s_minor_dim_registers;++ i){
1438  tmp.m_registers[i] = vals[i].transpose_shuffle_left(lvl, vals[i+skip_reg]);
1439  }
1440 
1441  }
1442  else{
1443  for(camp::idx_t i = major*s_minor_dim_registers;i < (major+1)*s_minor_dim_registers;++ i){
1444 
1445  tmp.m_registers[i] = vals[i-skip_reg].transpose_shuffle_right(lvl, vals[i]);
1446  }
1447  }
1448  }
1449  result = tmp;
1450 
1451  }
1452 
1453 
1454  // Now do the same Eklundh algorithm on registers, which is needed
1455  // if we have more than one register per input minor dim
1456  for(camp::idx_t lvl = 0; (1<<lvl) < s_minor_dim_registers;++ lvl){
1457 
1458 
1459  camp::idx_t skip_reg = 1<<lvl;
1460 
1461  auto const &vals = result.m_registers;
1462 
1463  self_type tmp;
1464  for(camp::idx_t major = 0;major < s_major_dim_elements;++ major){
// NOTE(review): skip_bits is out of scope here (it was declared inside the
// per-level loop above); this would not compile if the region were enabled.
1465  if(((major>>skip_bits)&0x1) == 0){
1466  for(camp::idx_t minor = 0;minor < self_type::s_minor_dim_registers;++ minor){
1467 
1468  // extract value x or y
1469  camp::idx_t xy_select = (minor >> lvl) & 0x1;
1470 
1471  camp::idx_t reg = major*s_minor_dim_registers + minor;
1472  camp::idx_t reg_x = major*s_minor_dim_registers + minor;
1473  camp::idx_t reg_y = (major+skip_reg)*s_minor_dim_registers + minor;
1474 
1475 
1476  tmp.m_registers[reg] =
1477  xy_select == 0 ? result.m_registers[reg_x] : result.m_registers[reg_y];
1478 
1479  }
1480  }
// NOTE(review): empty branch -- tmp registers for these `major` values are
// never assigned before `result = tmp` below.
1481  else{
1482 
1483  }
1484  }
1485  result = tmp;
1486 
1487  }
1488 
1489  }
1490 
1491  transpose_type *tptr = reinterpret_cast<transpose_type*>(&result);
1492 
1493 
1494 
1495 
1496  return *tptr;
1497  }
1498 
1499 
// Replaces *this with its transpose (assigns transpose_type to self_type).
1506  RAJA_INLINE
1507  void inplace_transpose() {
1508  *this = transpose();
1509  }
1510 
// Views *this as transpose_tensor_type without copying (reinterpret_cast).
1520  RAJA_INLINE
1521  transpose_tensor_type const &transpose_by_type() const {
1522  return reinterpret_cast<transpose_tensor_type const &>(*this);
1523  }
1524 #endif
1529 
/// Returns the product of this matrix with v, v multiplied on the right.
/// Starts from a zero-filled column vector and delegates the arithmetic to
/// right_multiply_vector_accumulate().
/// NOTE(review): the signature line (1531) is missing from this listing; the
/// generated index gives `column_vector_type right_multiply_vector(row_vector_type v) const`.
1530  RAJA_INLINE
1532  {
1533  column_vector_type result(0);
1534  return right_multiply_vector_accumulate(v, result);
1535  }
1536 
1541 
/// Returns the product of v with this matrix, v multiplied on the left.
/// Starts from a zero-filled row vector and delegates the arithmetic to
/// left_multiply_vector_accumulate().
/// NOTE(review): the signature line (1543) is missing from this listing; the
/// generated index gives `row_vector_type left_multiply_vector(column_vector_type v) const`.
1542  RAJA_INLINE
1544  {
1545  row_vector_type result(0);
1546  return left_multiply_vector_accumulate(v, result);
1547  }
1548 
1556 
/// Computes result + (this matrix * v) and returns it; v is multiplied on the
/// right of the matrix. `result` is taken by value, so the caller's
/// accumulator is not modified.
/// Four code paths: row- vs column-major layout, each split on whether one
/// register holds several rows/columns (s_minor_dim_registers == 0) or each
/// row/column spans one or more whole registers.
/// NOTE(review): the line naming the function (1558) and any attribute lines
/// before 1557 are missing from this listing; the generated index gives
/// `column_vector_type right_multiply_vector_accumulate(row_vector_type const &v, column_vector_type result) const`.
1557  RAJA_INLINE
1559  row_vector_type const& v,
1560  column_vector_type result) const
1561  {
1562 
1563  if (layout_type::is_row_major())
1564  {
1565 
1566  // 1 register is split over multiple rows (less than one register per row)
1567  if (s_minor_dim_registers == 0)
1568  {
1569 
1570  // start by broadcasting the first segment in v across all of v
1571  // we will use this term for all registers in the matrix
// NOTE(review): only register 0 of v is read on this path; presumably v
// fits in a single register when a row is shorter than one register.
1572  auto vv = v.get_register(0).segmented_broadcast_inner(s_segbits, 0);
1573 
1574  // loop over output segments, which is also the number of
1575  // registers in the matrix (no kidding!)
1576  RAJA_UNROLL
1577  for (camp::idx_t outseg = 0; outseg < s_num_registers; ++outseg)
1578  {
1579 
1580  // compute which result register we are accumulating into
1581  camp::idx_t result_reg = outseg >> s_segbits;
1582 
1583  // compute which segment within result_reg we are accumulating into
// (equivalently outseg modulo 2^s_segbits)
1584  camp::idx_t result_seg = outseg - (result_reg << s_segbits);
1585 
1586  // compute segmented dot product to get output segment
1587  auto value =
1588  m_registers[outseg].segmented_dot(s_segbits, result_seg, vv);
1589 
1590  // accumulate result
1591  result.get_register(result_reg) += value;
1592  }
1593  }
1594  // one or more registers per row
1595  else
1596  {
1597 
1598  // Loop over rows; `reg` walks m_registers in storage order
1599  camp::idx_t reg = 0;
1600  RAJA_UNROLL
1601  for (camp::idx_t row = 0; row < s_num_rows; ++row)
1602  {
1603 
1604  // compute partial dot products for all registers in this row
1605  auto rowsum = register_type(0);
1606  RAJA_UNROLL
1607  for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
1608  {
1609 
1610  rowsum =
1611  m_registers[reg].multiply_add(v.get_register(colreg), rowsum);
1612  reg++;
1613 
1614  } // colreg
1615 
1616  // finish dot product by taking sum of rowsum
1617  auto value = result.get(row) + rowsum.sum();
1618  result.set(value, row);
1619 
1620  } // row
1621  }
1622  }
1623  else
1624  {
1625 
1626 
1627  // 1 register is split over multiple columns (less than one register per column)
1628  if (s_minor_dim_registers == 0)
1629  {
1630 
// mv is a reference, so updates land directly in result's register 0
1631  auto& mv = result.get_register(0);
1632 
1633  // Loop over registers, which are also the segments in v
1634  RAJA_UNROLL
1635  for (camp::idx_t m_reg = 0; m_reg < s_num_registers; ++m_reg)
1636  {
1637  camp::idx_t v_reg = m_reg >> s_segbits;
1638  camp::idx_t v_seg = m_reg & ((1 << s_segbits) - 1);
1639 
1640  auto v_tmp =
1641  v.get_register(v_reg).segmented_broadcast_outer(s_segbits, v_seg);
1642  mv = m_registers[m_reg].multiply_add(v_tmp, mv);
1643  }
1644 
1645  // Now sum segments in mv together to form final result
1646  mv = mv.segmented_sum_outer(s_segbits, 0);
1647  }
1648  // one or more registers per column
1649  else
1650  {
1651 
1652  // Loop over columns (which is also registers)
1653  camp::idx_t reg = 0;
1654  RAJA_UNROLL
1655  for (camp::idx_t col = 0; col < s_num_columns; ++col)
1656  {
1657 
1658  // extract column value from v (presumably broadcast to every lane)
1659  auto v_col = register_type(v.get(col));
1660 
1661  // apply v_col to entire column (1 or more registers)
1662  RAJA_UNROLL
1663  for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
1664  {
1665 
1666  auto& mv = result.get_register(rowreg);
1667  mv = m_registers[reg].multiply_add(v_col, mv);
1668 
1669  reg++;
1670 
1671  } // rowreg
1672  } // col
1673  }
1674  }
1675  return result;
1676  }
1677 
1685 
/// Computes result + (v * this matrix) and returns it; v is multiplied on the
/// left of the matrix. `result` is taken by value, so the caller's
/// accumulator is not modified.
/// Mirrors right_multiply_vector_accumulate with the roles of the row- and
/// column-major code paths swapped.
/// NOTE(review): the line naming the function (1687) and any attribute lines
/// before 1686 are missing from this listing; the generated index gives
/// `row_vector_type left_multiply_vector_accumulate(column_vector_type const &v, row_vector_type result) const`.
1686  RAJA_INLINE
1688  row_vector_type result) const
1689  {
1690 
1691  if (layout_type::is_row_major())
1692  {
1693 
1694  // 1 register is split over multiple rows (less than one register per row)
1695  if (s_minor_dim_registers == 0)
1696  {
// vm is a reference, so updates land directly in result's register 0
1697  auto& vm = result.get_register(0);
1698 
1699  // Loop over registers, which are also the segments in v
1700  RAJA_UNROLL
1701  for (camp::idx_t m_reg = 0; m_reg < s_num_registers; ++m_reg)
1702  {
1703  camp::idx_t v_reg = m_reg >> s_segbits;
1704  camp::idx_t v_seg = m_reg & ((1 << s_segbits) - 1);
1705 
1706  auto v_tmp =
1707  v.get_register(v_reg).segmented_broadcast_outer(s_segbits, v_seg);
1708  vm = m_registers[m_reg].multiply_add(v_tmp, vm);
1709  }
1710 
1711  // Now sum segments in vm together to form final result
1712  vm = vm.segmented_sum_outer(s_segbits, 0);
1713  }
1714  // one or more registers per row
1715  else
1716  {
1717 
1718  // Loop over rows; `reg` walks m_registers in storage order
1719  camp::idx_t reg = 0;
1720  RAJA_UNROLL
1721  for (camp::idx_t row = 0; row < s_num_rows; ++row)
1722  {
1723  auto lhs_bcat = register_type(v.get(row));
1724  RAJA_UNROLL
1725  for (camp::idx_t colreg = 0; colreg < s_minor_dim_registers; ++colreg)
1726  {
1727 
1728  result.get_register(colreg) = m_registers[reg].multiply_add(
1729  lhs_bcat, result.get_register(colreg));
1730  reg++;
1731 
1732  } // colreg
1733  }
1734  }
1735 
1736 
1737  } // row-major
1738 
1739  // Column-major:
1740  else
1741  {
1742  // 1 register is split over multiple columns (less than one register per column)
1743  if (s_minor_dim_registers == 0)
1744  {
1745 
1746  // start by broadcasting the first segment in v across all of v
1747  // we will use this term for all registers in the matrix
1748  auto vv = v.get_register(0).segmented_broadcast_inner(s_segbits, 0);
1749 
1750  // loop over output segments, which is also the number of
1751  // registers in the matrix (no kidding!)
1752  RAJA_UNROLL
1753  for (camp::idx_t outseg = 0; outseg < s_num_registers; ++outseg)
1754  {
1755 
1756  // compute which result register we are accumulating into
1757  camp::idx_t result_reg = outseg >> s_segbits;
1758 
1759  // compute which segment within result_reg we are accumulating into
1760  camp::idx_t result_seg = outseg - (result_reg << s_segbits);
1761 
1762  // compute segmented dot product to get output segment
1763  auto value =
1764  m_registers[outseg].segmented_dot(s_segbits, result_seg, vv);
1765 
1766  // accumulate result
1767  result.get_register(result_reg) += value;
1768  }
1769  }
1770  // one or more registers per column
1771  else
1772  {
1773  // Loop over columns; `reg` walks m_registers in storage order
1774  camp::idx_t reg = 0;
1775  RAJA_UNROLL
1776  for (camp::idx_t col = 0; col < s_num_columns; ++col)
1777  {
1778 
1779  // compute partial dot products for all registers in this column
1780  auto colsum = register_type(0);
1781  RAJA_UNROLL
1782  for (camp::idx_t rowreg = 0; rowreg < s_minor_dim_registers; ++rowreg)
1783  {
1784  colsum =
1785  m_registers[reg].multiply_add(v.get_register(rowreg), colsum);
1786  reg++;
1787 
1788  } // rowreg
1789 
1790  // finish dot product by taking sum of colsum
1791  auto value = result.get(col) + colsum.sum();
1792  result.set(value, col);
1793 
1794  } // col
1795  }
1796 
1797 
1798  } // col-major
1799  return result;
1800  }
1801 
/// Returns the matrix product (*this) * mat. The result type is chosen by
/// RAJA::internal::expt::MatrixMatrixMultiplyHelper<self_type, RMAT>.
/// The result starts zero-filled and is handed, together with *this and mat,
/// to a helper routine -- presumably MatrixMatrixMultiplyHelper's
/// multiply_accumulate, as in matrix_multiply_add; the call's first line
/// (1812) is missing from this listing, as is line 1810.
1805  template<typename RMAT>
1806  RAJA_HOST_DEVICE RAJA_INLINE typename RAJA::internal::expt::
1807  MatrixMatrixMultiplyHelper<self_type, RMAT>::result_type
1808  matrix_multiply(RMAT const& mat) const
1809  {
1811  self_type, RMAT>::result_type res(0);
1813  *this, mat, res);
1814  return res;
1815  }
1816 
/// Returns (*this) * B + C without modifying C: the result is initialized
/// as a copy of C, then the helper's multiply_accumulate(*this, B, res) adds
/// the product into it.
/// NOTE(review): lines 1823, 1825, 1829 and 1831 (function name, parameter
/// type prefixes and helper qualification) are missing from this listing;
/// the generated index gives the full signature.
1820  template<typename RMAT>
1821  RAJA_HOST_DEVICE RAJA_INLINE typename RAJA::internal::expt::
1822  MatrixMatrixMultiplyHelper<self_type, RMAT>::result_type
1824  RMAT const& B,
1826  self_type,
1827  RMAT>::result_type const& C) const
1828  {
1830  self_type, RMAT>::result_type res(C);
1832  self_type, RMAT>::multiply_accumulate(*this, B, res);
1833  return res;
1834  }
1835 
/// In-place variant of matrix_multiply_add: adds (*this) * B into `acc`,
/// which is modified through the reference; nothing is returned.
/// NOTE(review): lines 1840 and 1844 (function name and helper
/// qualification) are missing from this listing.
1839  template<typename ACCMAT, typename RMAT>
1841  ACCMAT& acc,
1842  RMAT const& B) const
1843  {
1845  self_type, RMAT>::multiply_accumulate(*this, B, acc);
1846  }
1847 
1849 
1850  RAJA_INLINE
1851  self_type& set(element_type val, int row, int col)
1852  {
1853  m_registers[to_register(row, col)].set(val, to_lane(row, col));
1854  return *this;
1855  }
1856 
1858 
1859  RAJA_INLINE
1860  element_type get(int row, int col) const
1861  {
1862  return m_registers[to_register(row, col)].get(to_lane(row, col));
1863  }
1864 
1866 
1867  RAJA_INLINE
1868  register_type extract_diagonal_register(camp::idx_t starting_column,
1869  camp::idx_t segbits,
1870  camp::idx_t segment) const
1871  {
1872 
1873  register_type result(0);
1874 
1875  camp::idx_t num_rows = register_type::s_num_elem >> segbits;
1876  camp::idx_t num_repeats = 1 << segbits;
1877 
1878  camp::idx_t col0 = (starting_column + num_rows * segment) % s_num_columns;
1879  camp::idx_t row0 = num_rows * segment;
1880 
1881  for (camp::idx_t i = 0; i < num_rows; ++i)
1882  {
1883  camp::idx_t col = (col0 + i) % s_num_columns;
1884  camp::idx_t row = row0 + i;
1885  auto value = get(row, col);
1886  for (camp::idx_t j = 0; j < num_repeats; ++j)
1887  {
1888  result.set(value, (i << segbits) + j);
1889  }
1890  }
1891 
1892  return result;
1893  }
1894 
1900  RAJA_INLINE
1901  std::string to_string(bool one_line = false) const
1902  {
1903  std::string s = "Matrix(" + std::to_string(s_num_rows) + "x" +
1904  std::to_string(s_num_columns);
1905  if (!one_line)
1906  {
1907  s += ")\n";
1908  }
1909 
1910 
1911  s += "[ ";
1912 
1913  //
1914  for (camp::idx_t r = 0; r < s_num_rows; ++r)
1915  {
1916  if (r > 0)
1917  {
1918  s += ", ";
1919  if (!one_line)
1920  {
1921  s += "\n ";
1922  }
1923  }
1924  s += "[";
1925  for (camp::idx_t c = 0; c < s_num_columns; ++c)
1926  {
1927  if (c > 0)
1928  {
1929  s += ", ";
1930  }
1931  s += std::to_string(this->get(r, c));
1932  }
1933  s += "]";
1934  }
1935 
1936  s += " ]";
1937  if (!one_line)
1938  {
1939  s += "\n";
1940  }
1941  return s;
1942  }
1943 
1944 }; // MatrixRegisterImpl
1945 
1946 
1947 } // namespace expt
1948 } // namespace RAJA
1949 
1950 
1951 #endif
RAJA header file defining a bit masking operator.
RAJA header file defining SIMD/SIMT register operations.
RAJA header file defining SIMD/SIMT register operations.
Definition: RegisterBase.hpp:39
RAJA_HOST_DEVICE RAJA_INLINE RAJA::internal::expt::MatrixMatrixMultiplyHelper< self_type, RMAT >::result_type matrix_multiply_add(RMAT const &B, typename RAJA::internal::expt::MatrixMatrixMultiplyHelper< self_type, RMAT >::result_type const &C) const
Definition: MatrixRegisterImpl.hpp:1823
RAJA_HOST_DEVICE constexpr RAJA_INLINE TensorRegister()
Definition: MatrixRegisterImpl.hpp:163
RAJA_HOST_DEVICE RAJA_INLINE self_type & operator=(element_type value)
Set entire matrix to a single scalar value.
Definition: MatrixRegisterImpl.hpp:212
RAJA_INLINE std::string to_string(bool one_line=false) const
Converts the matrix to a string.
Definition: MatrixRegisterImpl.hpp:1901
RAJA_HOST_DEVICE RAJA_INLINE self_type const & store_ref(REF_TYPE &ref) const
Definition: MatrixRegisterImpl.hpp:254
RAJA_HOST_DEVICE RAJA_INLINE register_type extract_diagonal_register(camp::idx_t starting_column, camp::idx_t segbits, camp::idx_t segment) const
Definition: MatrixRegisterImpl.hpp:1868
RAJA_HOST_DEVICE RAJA_INLINE self_type const & store_strided_nm(element_type *ptr, int row_stride, int col_stride, int num_rows, int num_cols) const
Definition: MatrixRegisterImpl.hpp:1156
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type const & store_packed(element_type *ptr, int row_stride, int col_stride) const
Definition: MatrixRegisterImpl.hpp:912
self_type operator*(SquareMatrixRegister< T2, L, RP > const &y) const
Definition: MatrixRegisterImpl.hpp:227
RAJA_HOST_DEVICE RAJA_INLINE self_type & set(element_type val, int row, int col)
Definition: MatrixRegisterImpl.hpp:1851
RAJA_HOST_DEVICE RAJA_INLINE RAJA::internal::expt::MatrixMatrixMultiplyHelper< self_type, RMAT >::result_type matrix_multiply(RMAT const &mat) const
Definition: MatrixRegisterImpl.hpp:1808
VectorRegister< T2, RP > operator*(VectorRegister< T2, RP > const &y) const
Definition: MatrixRegisterImpl.hpp:237
RAJA_HOST_DEVICE RAJA_INLINE self_type & load_packed_nm(element_type const *ptr, int row_stride, int col_stride, int num_rows, int num_cols)
Definition: MatrixRegisterImpl.hpp:646
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE column_vector_type right_multiply_vector_accumulate(row_vector_type const &v, column_vector_type result) const
Definition: MatrixRegisterImpl.hpp:1558
RAJA_HOST_DEVICE RAJA_INLINE self_type & load_strided(element_type const *ptr, int row_stride, int col_stride)
Definition: MatrixRegisterImpl.hpp:573
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type divide_nm(self_type const &mat, int num_rows, int num_cols) const
Definition: MatrixRegisterImpl.hpp:1272
RAJA_HOST_DEVICE static constexpr RAJA_INLINE bool is_ref_packed()
Definition: MatrixRegisterImpl.hpp:188
RAJA_HOST_DEVICE RAJA_INLINE column_vector_type right_multiply_vector(row_vector_type v) const
Definition: MatrixRegisterImpl.hpp:1531
RAJA_HOST_DEVICE RAJA_INLINE self_type & load_ref(REF_TYPE const &ref)
Definition: MatrixRegisterImpl.hpp:247
RAJA_HOST_DEVICE RAJA_INLINE self_type & load_packed(element_type const *ptr, int row_stride, int col_stride)
Definition: MatrixRegisterImpl.hpp:491
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE self_type const & store_strided(element_type *ptr, int row_stride, int col_stride) const
Definition: MatrixRegisterImpl.hpp:984
RAJA_HOST_DEVICE RAJA_INLINE row_vector_type left_multiply_vector(column_vector_type v) const
Definition: MatrixRegisterImpl.hpp:1543
RAJA_HOST_DEVICE RAJA_INLINE self_type & operator=(self_type const &c)
Definition: MatrixRegisterImpl.hpp:221
RAJA_HOST_DEVICE RAJA_INLINE self_type const & store_packed_nm(element_type *ptr, int row_stride, int col_stride, int num_rows, int num_cols) const
Definition: MatrixRegisterImpl.hpp:1054
RAJA_HOST_DEVICE static constexpr RAJA_INLINE camp::idx_t s_dim_elem(camp::idx_t dim)
Definition: MatrixRegisterImpl.hpp:200
RAJA_SUPPRESS_HD_WARN RAJA_HOST_DEVICE RAJA_INLINE row_vector_type left_multiply_vector_accumulate(column_vector_type const &v, row_vector_type result) const
Definition: MatrixRegisterImpl.hpp:1687
RAJA_HOST_DEVICE RAJA_INLINE element_type get(int row, int col) const
Definition: MatrixRegisterImpl.hpp:1860
RAJA_HOST_DEVICE RAJA_INLINE void matrix_multiply_accumulate(ACCMAT &acc, RMAT const &B) const
Definition: MatrixRegisterImpl.hpp:1840
RAJA_HOST_DEVICE RAJA_INLINE self_type & load_strided_nm(element_type const *ptr, int row_stride, int col_stride, int num_rows, int num_cols)
Definition: MatrixRegisterImpl.hpp:782
RAJA_HOST_DEVICE RAJA_INLINE TensorRegister(element_type c)
Definition: MatrixRegisterImpl.hpp:168
RAJA_INLINE RAJA_HOST_DEVICE TensorRegister(self_type const &c)
Definition: MatrixRegisterImpl.hpp:173
Definition: TensorRegister.hpp:46
Definition: TensorRegisterBase.hpp:105
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
#define RAJA_SUPPRESS_HD_WARN
Definition: macros.hpp:68
TensorTileSize
Definition: TensorRef.hpp:234
@ TENSOR_FULL
Definition: TensorRef.hpp:236
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE constexpr RAJA_INLINE RAJA::zip_tuple_element_t< I, zip_tuple< is_val, Ts... > > & get(zip_tuple< is_val, Ts... > &z) noexcept
Definition: zip_tuple.hpp:56
Definition: BitMask.hpp:30
Definition: TensorLayout.hpp:35
RAJA_INLINE static RAJA_HOST_DEVICE void store_ref(self_type const &self, RefType &ref)
Performs store specified by TensorRef object.
Definition: MatrixRegisterImpl.hpp:322
RAJA_INLINE static RAJA_HOST_DEVICE void load_ref(self_type &self, RefType const &ref)
Performs load specified by TensorRef object.
Definition: MatrixRegisterImpl.hpp:278
Definition: MatrixMatrixMultiply.hpp:36
Definition: TensorRef.hpp:472
Definition: TensorRef.hpp:426
index_type m_stride[NUM_DIMS]
Definition: TensorRef.hpp:442
pointer_type m_pointer
Definition: TensorRef.hpp:441
tile_type m_tile
Definition: TensorRef.hpp:443
index_type m_begin[NUM_DIMS]
Definition: TensorRef.hpp:246
index_type m_size[NUM_DIMS]
Definition: TensorRef.hpp:247