Morpheus 1.0.0
Dynamic matrix type and algorithms for sparse matrices
Loading...
Searching...
No Matches
DenseVector/Cuda/Morpheus_WAXPBY_Impl.hpp
1
24#ifndef MORPHEUS_DENSEVECTOR_CUDA_WAXPBY_IMPL_HPP
25#define MORPHEUS_DENSEVECTOR_CUDA_WAXPBY_IMPL_HPP
26
27#include <Morpheus_Macros.hpp>
28#if defined(MORPHEUS_ENABLE_CUDA)
29
30#include <Morpheus_SpaceTraits.hpp>
31#include <Morpheus_FormatTraits.hpp>
32#include <Morpheus_FormatTags.hpp>
33#include <Morpheus_Spaces.hpp>
34
35#include <impl/Morpheus_CudaUtils.hpp>
36#include <impl/DenseVector/Kernels/Morpheus_WAXPBY_Impl.hpp>
37
38#include <cassert>
39
namespace Morpheus {
namespace Impl {

/**
 * @brief CUDA backend for the dense-vector WAXPBY operation.
 *
 * Launches Kernels::waxpby_kernel over the first @p n entries. Following the
 * WAXPBY naming convention, that kernel is expected to compute
 * w[i] = alpha * x[i] + beta * y[i] for 0 <= i < n (NOTE(review): confirm
 * against impl/DenseVector/Kernels/Morpheus_WAXPBY_Impl.hpp).
 *
 * This overload participates in overload resolution only when Vector1,
 * Vector2 and Vector3 are dense-vector containers and ExecSpace is a
 * custom-backend CUDA execution space with access to all three containers.
 *
 * Launch configuration: 1-D grid of ceil(n / 256) blocks, 256 threads per
 * block, no dynamic shared memory, default stream (no stream argument is
 * passed). The kernel covers indices up to NUM_BLOCKS * 256, so it presumably
 * guards against i >= n itself (TODO confirm in the kernel source).
 *
 * No synchronization is performed here: the launch is asynchronous, so host
 * code must synchronize before reading @p w from the host.
 *
 * @tparam ExecSpace CUDA execution space that selects this backend.
 * @tparam SizeType  Type of the element count and of the launch dimensions.
 * @param n     Number of leading elements to process. Must not exceed the size
 *              of x, y or w (checked with assert). A non-positive n is a no-op.
 * @param alpha Scalar multiplier for x.
 * @param x     First input vector; its data() pointer is passed to the kernel.
 * @param beta  Scalar multiplier for y.
 * @param y     Second input vector; its data() pointer is passed to the kernel.
 * @param w     Output vector; its data() pointer is passed to the kernel.
 */
template <typename ExecSpace, typename SizeType, typename Vector1,
          typename Vector2, typename Vector3>
inline void waxpby(
    const SizeType n, const typename Vector1::value_type alpha,
    const Vector1& x, const typename Vector2::value_type beta, const Vector2& y,
    Vector3& w,
    typename std::enable_if_t<
        Morpheus::is_dense_vector_format_container_v<Vector1> &&
        Morpheus::is_dense_vector_format_container_v<Vector2> &&
        Morpheus::is_dense_vector_format_container_v<Vector3> &&
        Morpheus::has_custom_backend_v<ExecSpace> &&
        Morpheus::has_cuda_execution_space_v<ExecSpace> &&
        Morpheus::has_access_v<ExecSpace, Vector1, Vector2, Vector3>>* =
        nullptr) {
  assert(x.size() >= n);
  assert(y.size() >= n);
  assert(w.size() >= n);

  // Nothing to compute. Returning early also avoids a zero-block launch:
  // CUDA rejects gridDim == 0 as an invalid configuration, and that error
  // would otherwise stay pending until some unrelated cudaGetLastError() call.
  if (n <= 0) return;

  const SizeType BLOCK_SIZE = 256;
  const SizeType NUM_BLOCKS = Impl::ceil_div(n, BLOCK_SIZE);

  Kernels::waxpby_kernel<SizeType, typename Vector1::value_type,
                         typename Vector2::value_type,
                         typename Vector3::value_type>
      <<<NUM_BLOCKS, BLOCK_SIZE, 0>>>(n, alpha, x.data(), beta, y.data(),
                                      w.data());
#if defined(DEBUG) || defined(MORPHEUS_DEBUG)
  // Debug-only launch check; the message names the kernel actually launched.
  getLastCudaError("waxpby_kernel: Kernel execution failed");
#endif
}

}  // namespace Impl
}  // namespace Morpheus
76
77#endif // MORPHEUS_ENABLE_CUDA
78#endif // MORPHEUS_DENSEVECTOR_CUDA_WAXPBY_IMPL_HPP
Generic Morpheus interfaces.
Definition: dummy.cpp:24