Morpheus 1.0.0
Dynamic matrix type and algorithms for sparse matrices
Loading...
Searching...
No Matches
DenseVector/Kokkos/Morpheus_WAXPBY_Impl.hpp
1
24#ifndef MORPHEUS_DENSEVECTOR_KOKKOS_WAXPBY_IMPL_HPP
25#define MORPHEUS_DENSEVECTOR_KOKKOS_WAXPBY_IMPL_HPP
26
27#include <Morpheus_SpaceTraits.hpp>
28#include <Morpheus_FormatTraits.hpp>
29#include <Morpheus_FormatTags.hpp>
30#include <Morpheus_Spaces.hpp>
31
32#include <cassert>
33
34namespace Morpheus {
35namespace Impl {
36
37template <typename ExecSpace, typename Vector1, typename Vector2,
38 typename Vector3>
39inline void waxpby(
40 const typename Vector1::size_type n,
41 const typename Vector1::value_type alpha, const Vector1& x,
42 const typename Vector2::value_type beta, const Vector2& y, Vector3& w,
43 typename std::enable_if_t<
44 Morpheus::is_dense_vector_format_container_v<Vector1> &&
45 Morpheus::is_dense_vector_format_container_v<Vector2> &&
46 Morpheus::is_dense_vector_format_container_v<Vector3> &&
47 Morpheus::has_generic_backend_v<ExecSpace> &&
48 Morpheus::has_access_v<ExecSpace, Vector1, Vector2, Vector3>>* =
49 nullptr) {
50 using execution_space = typename ExecSpace::execution_space;
51 using size_type = typename Vector1::size_type;
52 using range_policy = Kokkos::RangePolicy<size_type, execution_space>;
53
54 assert(x.size() >= n);
55 assert(y.size() >= n);
56 assert(w.size() >= n);
57
58 const typename Vector1::value_array_type x_view = x.const_view();
59 const typename Vector2::value_array_type y_view = y.const_view();
60 typename Vector3::value_array_type w_view = w.view();
61 range_policy policy(0, n);
62
63 if (alpha == 1.0) {
64 Kokkos::parallel_for(
65 "waxpby_alpha", policy, KOKKOS_LAMBDA(const size_type& i) {
66 w_view[i] = x_view[i] + beta * y_view[i];
67 });
68 } else if (beta == 1.0) {
69 Kokkos::parallel_for(
70 "waxpby_beta", policy, KOKKOS_LAMBDA(const size_type& i) {
71 w_view[i] = alpha * x_view[i] + y_view[i];
72 });
73 } else {
74 Kokkos::parallel_for(
75 "waxpby", policy, KOKKOS_LAMBDA(const size_type& i) {
76 w_view[i] = alpha * x_view[i] + beta * y_view[i];
77 });
78 }
79}
80
81} // namespace Impl
82} // namespace Morpheus
83
84#endif // MORPHEUS_DENSEVECTOR_KOKKOS_WAXPBY_IMPL_HPP
Generic Morpheus interfaces.
Definition: dummy.cpp:24