24#ifndef MORPHEUS_CSR_HIP_MULTIPLY_IMPL_HPP
25#define MORPHEUS_CSR_HIP_MULTIPLY_IMPL_HPP
27#include <Morpheus_Macros.hpp>
28#if defined(MORPHEUS_ENABLE_HIP)
30#include <Morpheus_SpaceTraits.hpp>
31#include <Morpheus_FormatTraits.hpp>
32#include <Morpheus_FormatTags.hpp>
33#include <Morpheus_Spaces.hpp>
35#include <impl/Morpheus_HIPUtils.hpp>
36#include <impl/Csr/Kernels/Morpheus_Multiply_Impl.hpp>
// Forward declaration of the CSR "vector" SpMV launcher, so the dispatching
// function that follows can call it before its definition appears below.
// NOTE(review): this listing omits the declaration's final parameter line.
// Every call site in this file passes four arguments (A, x, y, init), so the
// full declaration presumably ends with `const bool init);` -- confirm
// against the real header.
// NOTE(review): identifiers beginning with a double underscore are reserved
// for the implementation in C++; consider renaming in a follow-up.
42template <
typename Matrix,
typename Vector>
43void __spmv_csr_vector(
const Matrix& A,
const Vector& x, Vector& y,
// Forward declaration of the CSR "scalar" SpMV launcher, so the dispatching
// function that follows can call it before its definition appears below.
// NOTE(review): this listing omits the declaration's final parameter line.
// Every call site in this file passes four arguments (A, x, y, init), so the
// full declaration presumably ends with `const bool init);` -- confirm
// against the real header.
46template <
typename Matrix,
typename Vector>
47void __spmv_csr_scalar(
const Matrix& A,
const Vector& x, Vector& y,
// Host-side SpMV entry point for CSR matrices on the HIP backend.
//
// Participates in overload resolution (via enable_if) only when:
//   - Matrix is a CSR matrix container,
//   - Vector is a dense vector container,
//   - ExecSpace is a custom-backend HIP execution space, and
//   - ExecSpace has access to both Matrix and Vector.
//
// Template parameters: ExecSpace (execution space tag), Matrix (CSR
// container), Vector (dense vector container).
// Parameters:
//   A    - input CSR matrix; A.options() selects the kernel path.
//   x    - input dense vector.
//   y    - output dense vector.
//   init - forwarded unchanged to the selected launcher.
//
// NOTE(review): the line holding this function's return type and name is
// missing from this listing (presumably `multiply` -- confirm), as are the
// closing braces of the switch and of the function.
50template <
typename ExecSpace,
typename Matrix,
typename Vector>
52 const Matrix& A,
const Vector& x, Vector& y,
const bool init,
53 typename std::enable_if_t<
54 Morpheus::is_csr_matrix_format_container_v<Matrix> &&
55 Morpheus::is_dense_vector_format_container_v<Vector> &&
56 Morpheus::has_custom_backend_v<ExecSpace> &&
57 Morpheus::has_hip_execution_space_v<ExecSpace> &&
58 Morpheus::has_access_v<ExecSpace, Matrix, Vector>>* =
nullptr) {
// Matrices flagged MATOPT_SHORT_ROWS take the "scalar" kernel path; every
// other option value falls through to the "vector" kernel path.
59 switch (A.options()) {
60 case MATOPT_SHORT_ROWS: __spmv_csr_scalar(A, x, y, init);
break;
61 default: __spmv_csr_vector(A, x, y, init);
// Launches Kernels::spmv_csr_scalar_kernel to compute the CSR sparse
// matrix-vector product into y on the HIP device. Whether the kernel
// accumulates into y or overwrites it is defined in the kernel, which is not
// shown here.
//
// Launch configuration: BLOCK_SIZE = 256 threads per block, and
// ceil_div(A.nrows(), BLOCK_SIZE) blocks, so at least A.nrows() threads are
// launched in total.
//
// Parameters: A (CSR matrix), x (input vector), y (output vector), init (see
// note below).
// NOTE(review): the `const bool init) {` signature line is missing from this
// listing, along with several lines (guards, braces, `#endif`).
// NOTE(review): if A.nrows() == 0 then NUM_BLOCKS is 0 and the launch requests
// an empty grid; consider returning early in that case.
65template <
typename Matrix,
typename Vector>
66void __spmv_csr_scalar(
const Matrix& A,
const Vector& x, Vector& y,
68 using size_type =
typename Matrix::size_type;
69 using index_type =
typename Matrix::index_type;
70 using value_type =
typename Matrix::value_type;
72 const size_type BLOCK_SIZE = 256;
73 const size_type NUM_BLOCKS = Impl::ceil_div<size_type>(A.nrows(), BLOCK_SIZE);
// Raw pointers to the CSR arrays (row offsets, column indices, values) and to
// the vector storage, passed straight to the kernel.
75 const index_type* I = A.crow_offsets().data();
76 const index_type* J = A.ccolumn_indices().data();
77 const value_type* V = A.cvalues().data();
79 const value_type* x_ptr = x.data();
80 value_type* y_ptr = y.data();
// Zero the output before the kernel runs.
// NOTE(review): any guard around this call is not visible in this listing;
// presumably it is conditional on `init` -- confirm.
83 y.assign(y.size(), 0);
// NOTE(review): the final template argument line of this kernel instantiation
// is missing from this listing.
86 Morpheus::Impl::Kernels::spmv_csr_scalar_kernel<size_type, index_type,
88 <<<NUM_BLOCKS, BLOCK_SIZE, 0>>>(A.nrows(), I, J, V, x_ptr, y_ptr);
// Debug builds check for a launch/execution error after the kernel call.
// NOTE(review): this is the HIP backend, which includes
// impl/Morpheus_HIPUtils.hpp, yet it calls getLastCudaError; confirm that
// name is provided for HIP builds, or whether a HIP-specific checker was
// intended.
89#if defined(DEBUG) || defined(MORPHEUS_DEBUG)
90 getLastCudaError(
"spmv_csr_scalar_kernel: Kernel execution failed");
// Launches Kernels::spmv_csr_vector_kernel for a fixed THREADS_PER_VECTOR,
// computing the CSR sparse matrix-vector product into y on the HIP device.
//
// Launch configuration:
//   - THREADS_PER_BLOCK = 128, split into VECTORS_PER_BLOCK =
//     128 / THREADS_PER_VECTOR groups. THREADS_PER_VECTOR must divide 128;
//     the callers in this file use 2, 4, 8, 16, 32 and 64.
//   - The grid size is ceil_div(A.nrows(), VECTORS_PER_BLOCK), capped at
//     MAX_BLOCKS as reported by max_active_blocks for this kernel
//     instantiation. That helper is not shown; presumably it returns an
//     occupancy-based bound.
//   - The cap suggests the kernel loops over rows that the first pass of the
//     grid does not cover -- confirm in the kernel.
//
// Parameters: A (CSR matrix), x (input vector), y (output vector), init (see
// note below).
// NOTE(review): the `const bool init) {` signature line is missing from this
// listing, along with several lines (guards, braces, `#endif`).
// NOTE(review): if A.nrows() == 0 the uncapped grid size is presumably 0,
// requesting an empty grid; consider returning early in that case.
94template <
size_t THREADS_PER_VECTOR,
typename Matrix,
typename Vector>
95void __spmv_csr_vector_dispatch(
const Matrix& A,
const Vector& x, Vector& y,
97 using size_type =
typename Matrix::size_type;
98 using index_type =
typename Matrix::index_type;
99 using value_type =
typename Matrix::value_type;
// Raw pointers to the CSR arrays and to the vector storage for the kernel.
101 const index_type* I = A.crow_offsets().data();
102 const index_type* J = A.ccolumn_indices().data();
103 const value_type* V = A.cvalues().data();
105 const value_type* x_ptr = x.data();
106 value_type* y_ptr = y.data();
108 const size_type THREADS_PER_BLOCK = 128;
109 const size_type VECTORS_PER_BLOCK = THREADS_PER_BLOCK / THREADS_PER_VECTOR;
// Zero the output before the kernel runs.
// NOTE(review): any guard around this call is not visible in this listing;
// presumably it is conditional on `init` -- confirm.
112 y.assign(y.size(), 0);
// Upper bound on the grid size for this exact kernel instantiation and block
// size (0 bytes of dynamic shared memory).
115 const size_type MAX_BLOCKS = max_active_blocks(
116 Kernels::spmv_csr_vector_kernel<size_type, index_type, value_type,
117 VECTORS_PER_BLOCK, THREADS_PER_VECTOR>,
118 THREADS_PER_BLOCK, 0);
// Enough blocks to give every row its own group, but never more than
// MAX_BLOCKS.
120 const size_type NUM_BLOCKS = std::min<size_type>(
121 MAX_BLOCKS, Impl::ceil_div<size_type>(A.nrows(), VECTORS_PER_BLOCK));
123 Kernels::spmv_csr_vector_kernel<size_type, index_type, value_type,
124 VECTORS_PER_BLOCK, THREADS_PER_VECTOR>
125 <<<NUM_BLOCKS, THREADS_PER_BLOCK, 0>>>(A.nrows(), I, J, V, x_ptr, y_ptr);
// Debug builds check for a launch/execution error after the kernel call.
// NOTE(review): this is the HIP backend, yet it calls getLastCudaError;
// confirm that name is provided for HIP builds, or whether a HIP-specific
// checker was intended.
127#if defined(DEBUG) || defined(MORPHEUS_DEBUG)
128 getLastCudaError(
"spmv_csr_vector_kernel: Kernel execution failed");
132template <
typename Matrix,
typename Vector>
133void __spmv_csr_vector(
const Matrix& A,
const Vector& x, Vector& y,
135 using size_type =
typename Matrix::size_type;
137 const size_type nnz_per_row = A.nnnz() / A.nrows();
139 if (nnz_per_row <= 2) {
140 __spmv_csr_vector_dispatch<2>(A, x, y, init);
143 if (nnz_per_row <= 4) {
144 __spmv_csr_vector_dispatch<4>(A, x, y, init);
147 if (nnz_per_row <= 8) {
148 __spmv_csr_vector_dispatch<8>(A, x, y, init);
151 if (nnz_per_row <= 16) {
152 __spmv_csr_vector_dispatch<16>(A, x, y, init);
156 if (nnz_per_row <= 32) {
157 __spmv_csr_vector_dispatch<32>(A, x, y, init);
161 __spmv_csr_vector_dispatch<64>(A, x, y, init);
Generic Morpheus interfaces.
Definition: dummy.cpp:24