21 #ifndef RAJA_pattern_tensor_ET_MultiplyOperator_HPP
22 #define RAJA_pattern_tensor_ET_MultiplyOperator_HPP
31 class TensorBlockConcreteBase;
44 template<
typename LEFT_OPERAND_TYPE,
45 typename RIGHT_OPERAND_TYPE,
51 static constexpr camp::idx_t
s_num_dims = LEFT_OPERAND_TYPE::s_num_dims;
59 (
int)RIGHT_OPERAND_TYPE::s_num_dims);
66 LEFT_OPERAND_TYPE
const& left,
67 RIGHT_OPERAND_TYPE
const& right)
69 return dim == 0 ? left.getDimSize(0) : right.getDimSize(1);
75 template<
typename TILE_TYPE>
77 TILE_TYPE
const&
tile,
78 LEFT_OPERAND_TYPE
const& left,
79 RIGHT_OPERAND_TYPE
const& right)
80 -> decltype(left.eval(
tile) * right.eval(
tile))
82 return left.eval(
tile) * right.eval(
tile);
88 template<
typename TILE_TYPE,
typename ADD_OPERAND_TYPE>
90 TILE_TYPE
const&
tile,
91 LEFT_OPERAND_TYPE
const& left,
92 RIGHT_OPERAND_TYPE
const& right,
93 ADD_OPERAND_TYPE
const& add)
94 -> decltype(left.eval(
tile).multiply_add(right.eval(
tile),
97 return left.eval(
tile).multiply_add(right.eval(
tile), add.eval(
tile));
103 template<
typename TILE_TYPE,
typename SUBTRACT_OPERAND_TYPE>
105 TILE_TYPE
const&
tile,
106 LEFT_OPERAND_TYPE
const& left,
107 RIGHT_OPERAND_TYPE
const& right,
108 SUBTRACT_OPERAND_TYPE
const& subtract)
109 -> decltype(left.eval(
tile).multiply_subtract(right.eval(
tile),
110 subtract.eval(
tile)))
112 return left.eval(
tile).multiply_subtract(right.eval(
tile),
113 subtract.eval(
tile));
120 template<
typename LEFT_OPERAND_TYPE,
typename RIGHT_OPERAND_TYPE>
124 typename
std::enable_if<LEFT_OPERAND_TYPE::s_num_dims == 0>::type>
128 static constexpr camp::idx_t
s_num_dims = RIGHT_OPERAND_TYPE::s_num_dims;
139 LEFT_OPERAND_TYPE
const&,
140 RIGHT_OPERAND_TYPE
const& right)
142 return right.getDimSize(dim);
148 template<
typename TILE_TYPE>
150 TILE_TYPE
const&
tile,
151 LEFT_OPERAND_TYPE
const& left,
152 RIGHT_OPERAND_TYPE
const& right)
153 -> decltype(right.eval(
tile).scale(left.eval(
tile)))
155 return right.eval(
tile).scale(left.eval(
tile));
161 template<
typename TILE_TYPE,
typename ADD_OPERAND_TYPE>
163 TILE_TYPE
const&
tile,
164 LEFT_OPERAND_TYPE
const& left,
165 RIGHT_OPERAND_TYPE
const& right,
166 ADD_OPERAND_TYPE
const& add)
167 -> decltype(right.eval(
tile).scale(left.eval(
tile)) + add.eval(
tile))
169 return right.eval(
tile).scale(left.eval(
tile)) + add.eval(
tile);
175 template<
typename TILE_TYPE,
typename SUBTRACT_OPERAND_TYPE>
177 TILE_TYPE
const&
tile,
178 LEFT_OPERAND_TYPE
const& left,
179 RIGHT_OPERAND_TYPE
const& right,
180 SUBTRACT_OPERAND_TYPE
const& subtract)
181 -> decltype(right.eval(
tile).scale(left.eval(
tile)) - subtract.eval(
tile))
183 return right.eval(
tile).scale(left.eval(
tile)) - subtract.eval(
tile);
190 template<
typename LEFT_OPERAND_TYPE,
typename RIGHT_OPERAND_TYPE>
194 typename
std::enable_if<RIGHT_OPERAND_TYPE::s_num_dims == 0>::type>
198 static constexpr camp::idx_t
s_num_dims = LEFT_OPERAND_TYPE::s_num_dims;
209 LEFT_OPERAND_TYPE
const& left,
210 RIGHT_OPERAND_TYPE
const&)
212 return left.getDimSize(dim);
218 template<
typename TILE_TYPE>
220 TILE_TYPE
const&
tile,
221 LEFT_OPERAND_TYPE
const& left,
222 RIGHT_OPERAND_TYPE
const& right)
223 -> decltype(left.eval(
tile).scale(right.eval(
tile)))
225 return left.eval(
tile).scale(right.eval(
tile));
231 template<
typename TILE_TYPE,
typename ADD_OPERAND_TYPE>
233 TILE_TYPE
const&
tile,
234 LEFT_OPERAND_TYPE
const& left,
235 RIGHT_OPERAND_TYPE
const& right,
236 ADD_OPERAND_TYPE
const& add)
237 -> decltype(left.eval(
tile).scale(right.eval(
tile)) + add.eval(
tile))
239 return left.eval(
tile).scale(right.eval(
tile)) + add.eval(
tile);
245 template<
typename TILE_TYPE,
typename SUBTRACT_OPERAND_TYPE>
247 TILE_TYPE
const&
tile,
248 LEFT_OPERAND_TYPE
const& left,
249 RIGHT_OPERAND_TYPE
const& right,
250 SUBTRACT_OPERAND_TYPE
const& subtract)
251 -> decltype(left.eval(
tile).scale(right.eval(
tile)) - subtract.eval(
tile))
253 return left.eval(
tile).scale(right.eval(
tile)) - subtract.eval(
tile);
269 template<
typename LEFT_OPERAND_TYPE,
typename RIGHT_OPERAND_TYPE>
273 typename
std::enable_if<LEFT_OPERAND_TYPE::s_num_dims == 2 &&
274 RIGHT_OPERAND_TYPE::s_num_dims == 1>::type>
280 typename LEFT_OPERAND_TYPE::result_type::column_vector_type;
292 LEFT_OPERAND_TYPE
const&,
293 RIGHT_OPERAND_TYPE
const& right)
295 return dim == 0 ? right.getDimSize(0) : 0;
301 template<
typename TILE_TYPE>
303 TILE_TYPE
const&
tile,
304 LEFT_OPERAND_TYPE
const& left,
305 RIGHT_OPERAND_TYPE
const& right)
312 multiply_into_result(result,
tile, left, right);
317 template<
typename TILE_TYPE,
typename ADD_TYPE>
319 TILE_TYPE
const&
tile,
320 LEFT_OPERAND_TYPE
const& left,
321 RIGHT_OPERAND_TYPE
const& right,
329 multiply_into_result(result,
tile, left, right);
335 template<
typename STORAGE,
typename TILE_TYPE,
typename INDEX =
void>
336 struct MultiplyBridge;
338 template<
typename STORAGE,
typename TILE_TYPE>
341 TILE_TYPE
const&
tile,
342 LEFT_OPERAND_TYPE
const& et_left,
343 RIGHT_OPERAND_TYPE
const& et_right)
348 auto tile_size = left_type::result_type::s_dim_elem(1);
349 auto k_size = et_left.getDimSize(1);
356 LEFT_OPERAND_TYPE::result_type::s_get_default_tile().nonstatic();
357 left_tile.m_begin[0] =
tile.m_begin[0];
358 left_tile.m_size[0] =
tile.m_size[0];
359 left_tile.m_size[1] = tile_size;
361 using RightType =
typename TILE_TYPE::nonstatic_self_type;
363 RightType right_tile =
tile;
364 right_tile.m_size[0] = tile_size;
367 decltype(k_size) k = 0;
368 for (; k + tile_size <= k_size; k += tile_size)
372 left_tile.m_begin[1] = k;
373 auto left = et_left.eval(left_tile);
375 right_tile.m_begin[0] = k;
376 auto right = et_right.eval(right_tile);
379 result = left.right_multiply_vector_accumulate(right, result);
385 left_part_tile.m_begin[1] = k;
386 left_part_tile.m_size[1] = k_size - k;
387 auto left = et_left.eval(left_part_tile);
390 right_part_tile.m_begin[0] = k;
391 right_part_tile.m_size[0] = k_size - k;
392 auto right = et_right.eval(right_part_tile);
395 result = left.right_multiply_vector_accumulate(right, result);
402 static_assert(!std::is_same<T, void>::value,
"diag");
405 template<
typename I, TensorTileSize TTS,
typename B,
typename S>
406 struct Diag<StaticTensorTile<I, TTS, B, S>>
408 static_assert(std::is_same<I, void>::value,
"diag");
411 template<
typename STORAGE,
typename TILE_TYPE,
typename INDEX>
412 struct MultiplyBridge
415 Diag<TILE_TYPE> diag;
420 static void multiply_into_result(STORAGE& result,
421 TILE_TYPE
const&
tile,
422 LEFT_OPERAND_TYPE
const& et_left,
423 RIGHT_OPERAND_TYPE
const& et_right)
428 auto tile_size = left_type::result_type::s_dim_elem(1);
429 auto k_size = et_left.getDimSize(1);
436 LEFT_OPERAND_TYPE::result_type::s_get_default_tile().nonstatic();
437 left_tile.m_begin[0] =
tile.m_begin[0];
438 left_tile.m_size[0] =
tile.m_size[0];
439 left_tile.m_size[1] = tile_size;
441 using RightType =
typename TILE_TYPE::nonstatic_self_type;
443 RightType right_tile =
tile;
444 right_tile.m_size[0] = tile_size;
447 decltype(k_size) k = 0;
448 for (; k + tile_size <= k_size; k += tile_size)
452 left_tile.m_begin[1] = k;
453 auto left = et_left.eval(left_tile);
455 right_tile.m_begin[0] = k;
456 auto right = et_right.eval(right_tile);
459 result = left.right_multiply_vector_accumulate(right, result);
465 left_part_tile.m_begin[1] = k;
466 left_part_tile.m_size[1] = k_size - k;
467 auto left = et_left.eval(left_part_tile);
470 right_part_tile.m_begin[0] = k;
471 right_part_tile.m_size[0] = k_size - k;
472 auto right = et_right.eval(right_part_tile);
475 result = left.right_multiply_vector_accumulate(right, result);
480 template<
size_t INDEX,
485 INDEX_TYPE... BeginTail,
487 INDEX_TYPE... SizeTail>
488 struct MultiplyBridge<
490 StaticTensorTile<INDEX_TYPE,
492 camp::int_seq<INDEX_TYPE, Begin0, BeginTail...>,
493 camp::int_seq<INDEX_TYPE, Size0, SizeTail...>>,
494 camp::integral_constant<size_t, INDEX>>
498 StaticTensorTile<INDEX_TYPE,
500 camp::int_seq<INDEX_TYPE, Begin0, BeginTail...>,
501 camp::int_seq<INDEX_TYPE, Size0, SizeTail...>>;
506 static void multiply_into_result(STORAGE& result,
507 TileType
const&
tile,
508 LEFT_OPERAND_TYPE
const& et_left,
509 RIGHT_OPERAND_TYPE
const& et_right)
513 const auto tile_size = left_type::result_type::s_dim_elem(1);
514 const auto k_size = et_left.getDimSize(1);
516 auto const offset = INDEX * tile_size;
518 if ((offset + tile_size) <= k_size)
522 StaticTensorTile<INDEX_TYPE, TENSOR_SIZE,
523 camp::int_seq<INDEX_TYPE, Begin0, offset>,
524 camp::int_seq<INDEX_TYPE, Size0, tile_size>>;
526 auto left = et_left.eval(LeftType());
529 StaticTensorTile<INDEX_TYPE, TENSOR_SIZE,
530 camp::int_seq<INDEX_TYPE, offset>,
531 camp::int_seq<INDEX_TYPE, tile_size>>;
533 auto right = et_right.eval(RightType());
536 auto temp = left.right_multiply_vector_accumulate(right, result);
537 MultiplyBridge<STORAGE, TileType,
538 camp::integral_constant<size_t, INDEX - 1>>::
539 multiply_into_result(result,
tile, et_left, et_right);
547 camp::int_seq<INDEX_TYPE, Begin0, offset>,
548 camp::int_seq<INDEX_TYPE, Size0, k_size - offset>>;
549 auto left = et_left.eval(LeftType());
553 camp::int_seq<INDEX_TYPE, offset>,
554 camp::int_seq<INDEX_TYPE, k_size - offset>>;
555 auto right = et_right.eval(RightType());
558 result = left.right_multiply_vector_accumulate(right, result);
563 template<
typename STORAGE,
567 INDEX_TYPE... BeginTail,
569 INDEX_TYPE... SizeTail>
570 struct MultiplyBridge<
572 StaticTensorTile<INDEX_TYPE,
574 camp::int_seq<INDEX_TYPE, Begin0, BeginTail...>,
575 camp::int_seq<INDEX_TYPE, Size0, SizeTail...>>,
576 camp::integral_constant<size_t, 0>>
580 StaticTensorTile<INDEX_TYPE,
582 camp::int_seq<INDEX_TYPE, Begin0, BeginTail...>,
583 camp::int_seq<INDEX_TYPE, Size0, SizeTail...>>;
588 static void multiply_into_result(STORAGE& result,
590 LEFT_OPERAND_TYPE
const& et_left,
591 RIGHT_OPERAND_TYPE
const& et_right)
595 const auto tile_size = left_type::result_type::s_dim_elem(1);
596 const auto k_size = et_left.getDimSize(1);
598 auto const offset = 0;
600 if ((offset + tile_size) <= k_size)
604 StaticTensorTile<INDEX_TYPE, TENSOR_SIZE,
605 camp::int_seq<INDEX_TYPE, Begin0, offset>,
606 camp::int_seq<INDEX_TYPE, Size0, tile_size>>;
608 auto left = et_left.eval(LeftType());
611 StaticTensorTile<INDEX_TYPE, TENSOR_SIZE,
612 camp::int_seq<INDEX_TYPE, offset>,
613 camp::int_seq<INDEX_TYPE, tile_size>>;
615 auto right = et_right.eval(RightType());
618 auto temp = left.right_multiply_vector_accumulate(right, result);
626 camp::int_seq<INDEX_TYPE, Begin0, offset>,
627 camp::int_seq<INDEX_TYPE, Size0, k_size - offset>>;
628 auto left = et_left.eval(LeftType());
632 camp::int_seq<INDEX_TYPE, offset>,
633 camp::int_seq<INDEX_TYPE, k_size - offset>>;
634 auto right = et_right.eval(RightType());
637 result = left.right_multiply_vector_accumulate(right, result);
642 template<
typename STORAGE,
646 INDEX_TYPE... BeginTail,
648 INDEX_TYPE... SizeTail>
649 struct MultiplyBridge<
651 StaticTensorTile<INDEX_TYPE,
653 camp::int_seq<INDEX_TYPE, Begin0, BeginTail...>,
654 camp::int_seq<INDEX_TYPE, Size0, SizeTail...>>,
659 StaticTensorTile<INDEX_TYPE,
661 camp::int_seq<INDEX_TYPE, Begin0, BeginTail...>,
662 camp::int_seq<INDEX_TYPE, Size0, SizeTail...>>;
667 static void multiply_into_result(STORAGE& result,
668 TileType
const&
tile,
669 LEFT_OPERAND_TYPE
const& et_left,
670 RIGHT_OPERAND_TYPE
const& et_right)
673 const auto tile_size = left_type::result_type::s_dim_elem(1);
674 const auto k_size = et_left.getDimSize(1);
675 const size_t iter_count =
676 (k_size / tile_size) + ((k_size % tile_size != 0) ? 1 : 0);
678 MultiplyBridge<STORAGE, TileType,
679 camp::integral_constant<size_t, iter_count>>::
680 multiply_into_result(result,
tile, et_left, et_right);
686 template<
typename LEFT_OPERAND_TYPE,
687 typename RIGHT_OPERAND_TYPE,
688 typename ADD_OPERAND_TYPE>
689 class TensorMultiplyAdd;
703 template<
typename LEFT_OPERAND_TYPE,
typename RIGHT_OPERAND_TYPE>
707 typename
std::enable_if<LEFT_OPERAND_TYPE::s_num_dims == 1 &&
708 RIGHT_OPERAND_TYPE::s_num_dims == 2>::type>
713 using result_type =
typename RIGHT_OPERAND_TYPE::result_type::row_vector_type;
725 LEFT_OPERAND_TYPE
const& left,
726 RIGHT_OPERAND_TYPE
const&)
728 return dim == 0 ? left.getDimSize(0) : 0;
734 template<
typename TILE_TYPE>
736 TILE_TYPE
const&
tile,
737 LEFT_OPERAND_TYPE
const& left,
738 RIGHT_OPERAND_TYPE
const& right)
744 multiply_into_result(result,
tile, left, right);
749 template<
typename TILE_TYPE,
typename ADD_TYPE>
751 TILE_TYPE
const&
tile,
752 LEFT_OPERAND_TYPE
const& left,
753 RIGHT_OPERAND_TYPE
const& right,
760 multiply_into_result(result,
tile, left, right);
766 template<
typename STORAGE,
typename TILE_TYPE>
769 TILE_TYPE
const&
tile,
770 LEFT_OPERAND_TYPE
const& et_left,
771 RIGHT_OPERAND_TYPE
const& et_right)
774 auto tile_size = right_type::result_type::s_dim_elem(0);
775 auto k_size = et_right.getDimSize(0);
784 RIGHT_OPERAND_TYPE::result_type::s_get_default_tile().nonstatic();
785 right_tile.m_begin[1] =
tile.m_begin[0];
786 right_tile.m_size[1] =
tile.m_size[0];
787 right_tile.m_size[0] = tile_size;
789 TILE_TYPE left_tile =
tile;
790 left_tile.m_size[0] = tile_size;
794 decltype(k_size) k = 0;
795 for (; k + tile_size <= k_size; k += tile_size)
799 right_tile.m_begin[0] = k;
800 auto right = et_right.eval(right_tile);
802 left_tile.m_begin[0] = k;
803 auto left = et_left.eval(left_tile);
806 result = right.left_multiply_vector_accumulate(left, result);
812 right_part_tile.m_begin[0] = k;
813 right_part_tile.m_size[0] = k_size - k;
814 auto right = et_right.eval(right_part_tile);
817 left_part_tile.m_begin[0] = k;
818 left_part_tile.m_size[0] = k_size - k;
819 auto left = et_left.eval(left_part_tile);
822 result = right.left_multiply_vector_accumulate(left, result);
834 template<
typename LEFT_OPERAND_TYPE,
typename RIGHT_OPERAND_TYPE>
838 typename
std::enable_if<LEFT_OPERAND_TYPE::s_num_dims == 2 &&
839 RIGHT_OPERAND_TYPE::s_num_dims == 2>::type>
844 using result_type =
typename LEFT_OPERAND_TYPE::result_type::product_type;
856 LEFT_OPERAND_TYPE
const& left,
857 RIGHT_OPERAND_TYPE
const& right)
859 return dim == 0 ? left.getDimSize(0) : right.getDimSize(1);
865 template<
typename TILE_TYPE>
867 TILE_TYPE
const&
tile,
868 LEFT_OPERAND_TYPE
const& left,
869 RIGHT_OPERAND_TYPE
const& right)
892 multiply_into_result(result,
tile, left, right);
897 template<
typename TILE_TYPE,
typename ADD_TYPE>
899 TILE_TYPE
const&
tile,
900 LEFT_OPERAND_TYPE
const& left,
901 RIGHT_OPERAND_TYPE
const& right,
908 multiply_into_result(result,
tile, left, right);
914 template<
typename STORAGE,
typename TILE_TYPE>
917 TILE_TYPE
const&
tile,
918 LEFT_OPERAND_TYPE
const& et_left,
919 RIGHT_OPERAND_TYPE
const& et_right)
922 using right_tensor_type =
typename right_type::result_type;
923 auto tile_size = right_tensor_type::s_dim_elem(0);
924 auto k_size = et_left.getDimSize(1);
931 TILE_TYPE left_tile =
tile;
932 left_tile.m_size[1] = tile_size;
933 auto left_begin = et_left.getDimBegin(1);
935 TILE_TYPE right_tile =
tile;
936 right_tile.m_size[0] = tile_size;
937 auto right_begin = et_right.getDimBegin(0);
941 decltype(k_size) k = 0;
942 for (; k + tile_size <= k_size; k += tile_size)
946 left_tile.m_begin[1] = k + left_begin;
947 auto left = et_left.eval(left_tile);
949 right_tile.m_begin[0] = k + right_begin;
950 auto right = et_right.eval(right_tile);
953 left.matrix_multiply_accumulate(result, right);
960 left_part_tile.m_begin[1] = k + left_begin;
961 left_part_tile.m_size[1] = k_size - k;
962 auto left = et_left.eval(left_part_tile);
965 right_part_tile.m_begin[0] = k + right_begin;
966 right_part_tile.m_size[0] = k_size - k;
967 auto right = et_right.eval(right_part_tile);
970 left.matrix_multiply_accumulate(result, right);
975 template<
typename OPERAND_TYPE,
typename TILE_TYPE>
985 static constexpr camp::idx_t
s_num_dims = OPERAND_TYPE::s_num_dims;
996 : m_operand {operand},
1005 return m_tile.m_size[dim];
1013 return m_tile.m_begin[dim];
1016 template<
typename TILE_TYPE2>
1018 -> decltype(m_operand.eval(
tile))
1020 return m_operand.eval(
tile);
1028 printf(
"RestrictExtents(");
1029 m_operand.print_ast();
1034 template<
typename OPERAND,
typename TILE>
1038 using tile_type =
typename OPERAND::tile_type;
1040 new_tile.copy(
tile);
1052 template<
typename LEFT_OPERAND_TYPE,
typename RIGHT_OPERAND_TYPE>
1056 typename
std::enable_if<
1057 std::is_base_of<TensorBlockConcreteBase,
1058 typename RIGHT_OPERAND_TYPE::tensor_type>::value &&
1059 LEFT_OPERAND_TYPE::s_num_dims == 2 &&
1060 RIGHT_OPERAND_TYPE::s_num_dims == 2>::type>
1064 using result_type =
typename LEFT_OPERAND_TYPE::result_type::product_type;
1092 LEFT_OPERAND_TYPE
const& left,
1093 RIGHT_OPERAND_TYPE
const& right)
1095 return dim == 0 ? left.getDimSize(0) : right.getDimSize(1);
1101 template<
typename TILE_TYPE>
1103 TILE_TYPE
const&
tile,
1104 LEFT_OPERAND_TYPE
const&,
1105 RIGHT_OPERAND_TYPE
const&)
1136 template<
typename TILE_TYPE,
typename ADD_TYPE>
1138 TILE_TYPE
const&
tile,
1139 LEFT_OPERAND_TYPE
const& left,
1140 RIGHT_OPERAND_TYPE
const& right,
1162 block_tile_type block_tile;
1163 block_tile.copy(
tile);
1171 result_et = add.eval(
tile);
1179 multiply_into_result(result_et,
tile, left, right);
1186 template<
typename STORAGE,
typename TILE_TYPE>
1189 TILE_TYPE
const&
tile,
1190 LEFT_OPERAND_TYPE
const& et_left,
1191 RIGHT_OPERAND_TYPE
const& et_right)
1195 auto tile_size = result_type::s_dim_elem(1);
1196 auto k_size = et_left.getDimSize(1);
1203 TILE_TYPE left_tile =
tile;
1204 left_tile.m_size[1] = tile_size;
1205 auto left_begin = et_left.getDimBegin(1);
1207 TILE_TYPE right_tile =
tile;
1208 right_tile.m_size[0] = tile_size;
1209 auto right_begin = et_right.getDimBegin(0);
1213 decltype(k_size) k = 0;
1214 for (; k + tile_size <= k_size; k += tile_size)
1219 left_tile.m_begin[1] = k + left_begin;
1220 auto left = et_left.eval(left_tile);
1222 right_tile.m_begin[0] = k + right_begin;
1223 auto right = et_right.eval(right_tile);
1235 left_part_tile.m_begin[1] = k + left_begin;
1236 left_part_tile.m_size[1] = k_size - k;
1237 auto left = et_left.eval(left_part_tile);
1240 right_part_tile.m_begin[0] = k + right_begin;
1241 right_part_tile.m_size[0] = k_size - k;
1242 auto right = et_right.eval(right_part_tile);
Definition: BlockLiteral.hpp:51
typename STORAGE_TYPE::ref_type ref_type
Definition: BlockLiteral.hpp:57
typename ref_type::tile_type tile_type
Definition: BlockLiteral.hpp:58
RAJA_INLINE RAJA_HOST_DEVICE ref_type get_ref()
Definition: BlockLiteral.hpp:102
Definition: MultiplyOperator.hpp:978
RAJA_INLINE constexpr RAJA_HOST_DEVICE index_type getDimBegin(camp::idx_t dim) const
Definition: MultiplyOperator.hpp:1011
typename OPERAND_TYPE::result_type result_type
Definition: MultiplyOperator.hpp:982
typename TILE_TYPE::index_type index_type
Definition: MultiplyOperator.hpp:983
RAJA_INLINE RAJA_HOST_DEVICE RestrictExtents(operand_type const &operand, tile_type const &tile)
Definition: MultiplyOperator.hpp:995
TILE_TYPE tile_type
Definition: MultiplyOperator.hpp:984
RAJA_INLINE RAJA_HOST_DEVICE void print_ast() const
Definition: MultiplyOperator.hpp:1026
RAJA_INLINE RAJA_HOST_DEVICE auto eval(TILE_TYPE2 const &tile) const -> decltype(m_operand.eval(tile))
Definition: MultiplyOperator.hpp:1017
OPERAND_TYPE operand_type
Definition: MultiplyOperator.hpp:981
static constexpr camp::idx_t s_num_dims
Definition: MultiplyOperator.hpp:985
RAJA_INLINE constexpr RAJA_HOST_DEVICE index_type getDimSize(index_type dim) const
Definition: MultiplyOperator.hpp:1003
Definition: ExpressionTemplateBase.hpp:72
Definition: TensorLoadStore.hpp:75
RAJA_SUPPRESS_HD_WARN RAJA_INLINE RAJA_HOST_DEVICE auto eval(TILE_TYPE const &tile) const -> decltype(tensor_type::s_load_ref(merge_ref_tile(m_ref, tile)))
Definition: TensorLoadStore.hpp:166
#define RAJA_HOST_DEVICE
Definition: macros.hpp:65
RestrictExtents< OPERAND, TILE > restrictExtents(OPERAND const &operand, TILE const &tile)
Definition: MultiplyOperator.hpp:1035
TensorTileSize
Definition: TensorRef.hpp:234
@ TENSOR_PARTIAL
Definition: TensorRef.hpp:235
RAJA_INLINE constexpr RAJA_HOST_DEVICE TensorTile< INDEX_TYPE, TENSOR_PARTIAL, NUM_DIMS > & make_tensor_tile_partial(TensorTile< INDEX_TYPE, RTENSOR_SIZE, NUM_DIMS > &tile)
Definition: TensorRef.hpp:733
Definition: AlignedRangeIndexSetBuilders.cpp:35
RAJA_HOST_DEVICE RAJA_INLINE void tile(CONTEXT const &ctx, TILE_T tile_size, SEGMENT const &segment, BODY const &body)
Definition: launch_core.hpp:589
Definition: ListSegment.hpp:416
RAJA_INLINE static RAJA_HOST_DEVICE void print_ast()
Definition: MultiplyOperator.hpp:1086
LEFT_OPERAND_TYPE left_type
Definition: MultiplyOperator.hpp:1062
RAJA_INLINE static RAJA_HOST_DEVICE block_literal multiply(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &, RIGHT_OPERAND_TYPE const &)
Definition: MultiplyOperator.hpp:1102
typename tensor_type::storage_type storage_type
Definition: MultiplyOperator.hpp:1076
typename RIGHT_OPERAND_TYPE::tensor_type tensor_type
Definition: MultiplyOperator.hpp:1073
typename LEFT_OPERAND_TYPE::result_type::product_type result_type
Definition: MultiplyOperator.hpp:1064
RAJA_INLINE static RAJA_HOST_DEVICE block_literal multiply_add(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right, ADD_TYPE const &add)
Definition: MultiplyOperator.hpp:1137
RAJA_INLINE static RAJA_HOST_DEVICE int getDimSize(int dim, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right)
Definition: MultiplyOperator.hpp:1091
RIGHT_OPERAND_TYPE right_type
Definition: MultiplyOperator.hpp:1063
RAJA_INLINE static RAJA_HOST_DEVICE void print_ast()
Definition: MultiplyOperator.hpp:719
RAJA_INLINE static RAJA_HOST_DEVICE result_type multiply_add(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right, ADD_TYPE const &add)
Definition: MultiplyOperator.hpp:750
RAJA_INLINE static RAJA_HOST_DEVICE int getDimSize(int dim, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &)
Definition: MultiplyOperator.hpp:724
typename RIGHT_OPERAND_TYPE::result_type::row_vector_type result_type
Definition: MultiplyOperator.hpp:713
RAJA_INLINE static RAJA_HOST_DEVICE result_type multiply(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right)
Definition: MultiplyOperator.hpp:735
LEFT_OPERAND_TYPE left_type
Definition: MultiplyOperator.hpp:711
RIGHT_OPERAND_TYPE right_type
Definition: MultiplyOperator.hpp:712
RAJA_INLINE static RAJA_HOST_DEVICE auto multiply(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right) -> decltype(left.eval(tile).scale(right.eval(tile)))
Definition: MultiplyOperator.hpp:219
typename LEFT_OPERAND_TYPE::result_type result_type
Definition: MultiplyOperator.hpp:197
RAJA_INLINE static RAJA_HOST_DEVICE void print_ast()
Definition: MultiplyOperator.hpp:203
RAJA_INLINE static RAJA_HOST_DEVICE int getDimSize(int dim, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &)
Definition: MultiplyOperator.hpp:208
RAJA_INLINE static RAJA_HOST_DEVICE auto multiply_add(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right, ADD_OPERAND_TYPE const &add) -> decltype(left.eval(tile).scale(right.eval(tile))+add.eval(tile))
Definition: MultiplyOperator.hpp:232
RAJA_INLINE static RAJA_HOST_DEVICE auto multiply_subtract(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right, SUBTRACT_OPERAND_TYPE const &subtract) -> decltype(left.eval(tile).scale(right.eval(tile)) - subtract.eval(tile))
Definition: MultiplyOperator.hpp:246
RAJA_INLINE static RAJA_HOST_DEVICE auto multiply(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right) -> decltype(right.eval(tile).scale(left.eval(tile)))
Definition: MultiplyOperator.hpp:149
typename RIGHT_OPERAND_TYPE::result_type result_type
Definition: MultiplyOperator.hpp:127
RAJA_INLINE static RAJA_HOST_DEVICE int getDimSize(int dim, LEFT_OPERAND_TYPE const &, RIGHT_OPERAND_TYPE const &right)
Definition: MultiplyOperator.hpp:138
RAJA_INLINE static RAJA_HOST_DEVICE auto multiply_add(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right, ADD_OPERAND_TYPE const &add) -> decltype(right.eval(tile).scale(left.eval(tile))+add.eval(tile))
Definition: MultiplyOperator.hpp:162
RAJA_INLINE static RAJA_HOST_DEVICE auto multiply_subtract(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right, SUBTRACT_OPERAND_TYPE const &subtract) -> decltype(right.eval(tile).scale(left.eval(tile)) - subtract.eval(tile))
Definition: MultiplyOperator.hpp:176
RAJA_INLINE static RAJA_HOST_DEVICE void print_ast()
Definition: MultiplyOperator.hpp:133
RAJA_INLINE static RAJA_HOST_DEVICE void print_ast()
Definition: MultiplyOperator.hpp:850
LEFT_OPERAND_TYPE left_type
Definition: MultiplyOperator.hpp:842
RAJA_INLINE static RAJA_HOST_DEVICE int getDimSize(int dim, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right)
Definition: MultiplyOperator.hpp:855
RAJA_INLINE static RAJA_HOST_DEVICE result_type multiply_add(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right, ADD_TYPE const &add)
Definition: MultiplyOperator.hpp:898
typename LEFT_OPERAND_TYPE::result_type::product_type result_type
Definition: MultiplyOperator.hpp:844
RIGHT_OPERAND_TYPE right_type
Definition: MultiplyOperator.hpp:843
RAJA_INLINE static RAJA_HOST_DEVICE result_type multiply(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right)
Definition: MultiplyOperator.hpp:866
RAJA_INLINE static RAJA_HOST_DEVICE result_type multiply_add(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right, ADD_TYPE const &add)
Definition: MultiplyOperator.hpp:318
LEFT_OPERAND_TYPE left_type
Definition: MultiplyOperator.hpp:277
RAJA_INLINE static RAJA_HOST_DEVICE result_type multiply(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right)
Definition: MultiplyOperator.hpp:302
RAJA_INLINE static RAJA_HOST_DEVICE void print_ast()
Definition: MultiplyOperator.hpp:286
typename LEFT_OPERAND_TYPE::result_type::column_vector_type result_type
Definition: MultiplyOperator.hpp:280
RIGHT_OPERAND_TYPE right_type
Definition: MultiplyOperator.hpp:278
RAJA_INLINE static RAJA_HOST_DEVICE int getDimSize(int dim, LEFT_OPERAND_TYPE const &, RIGHT_OPERAND_TYPE const &right)
Definition: MultiplyOperator.hpp:291
Definition: MultiplyOperator.hpp:48
RAJA_INLINE static RAJA_HOST_DEVICE auto multiply(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right) -> decltype(left.eval(tile) *right.eval(tile))
Definition: MultiplyOperator.hpp:76
static constexpr camp::idx_t s_num_dims
Definition: MultiplyOperator.hpp:51
RAJA_INLINE static RAJA_HOST_DEVICE auto multiply_subtract(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right, SUBTRACT_OPERAND_TYPE const &subtract) -> decltype(left.eval(tile).multiply_subtract(right.eval(tile), subtract.eval(tile)))
Definition: MultiplyOperator.hpp:104
typename LEFT_OPERAND_TYPE::result_type result_type
Definition: MultiplyOperator.hpp:50
RAJA_INLINE static RAJA_HOST_DEVICE void print_ast()
Definition: MultiplyOperator.hpp:56
RAJA_INLINE static RAJA_HOST_DEVICE auto multiply_add(TILE_TYPE const &tile, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right, ADD_OPERAND_TYPE const &add) -> decltype(left.eval(tile).multiply_add(right.eval(tile), add.eval(tile)))
Definition: MultiplyOperator.hpp:89
RAJA_INLINE static RAJA_HOST_DEVICE int getDimSize(int dim, LEFT_OPERAND_TYPE const &left, RIGHT_OPERAND_TYPE const &right)
Definition: MultiplyOperator.hpp:65