Morpheus 1.0.0
Dynamic matrix type and algorithms for sparse matrices
Loading...
Searching...
No Matches
Dia/Kernels/Morpheus_Multiply_Impl.hpp
1
25#ifndef MORPHEUS_DIA_KERNELS_MULTIPLY_IMPL_HPP
26#define MORPHEUS_DIA_KERNELS_MULTIPLY_IMPL_HPP
27
28#include <Morpheus_Macros.hpp>
29#if defined(MORPHEUS_ENABLE_CUDA) || defined(MORPHEUS_ENABLE_HIP)
30
31#include <impl/Morpheus_Utils.hpp>
32
33namespace Morpheus {
34namespace Impl {
35namespace Kernels {
36
/**
 * @brief Sparse matrix-vector product kernel for a matrix stored by diagonals.
 *
 * The diagonal offsets are staged through shared memory in chunks of at most
 * BLOCK_SIZE entries. For every chunk, each thread walks its rows (starting at
 * BLOCK_SIZE * blockIdx.x + threadIdx.x, advancing by BLOCK_SIZE * gridDim.x)
 * and adds the chunk's contribution to y[row]. The result is accumulated with
 * +=, so the previous contents of y are kept rather than overwritten.
 *
 * NOTE(review): the indexing uses BLOCK_SIZE rather than blockDim.x, so the
 * kernel presumably expects to be launched with exactly BLOCK_SIZE threads per
 * block - confirm against the launch site.
 *
 * @tparam SizeType   Type used for sizes and loop counters
 * @tparam IndexType  Type of the diagonal offsets and column indices
 * @tparam ValueType  Type of the matrix and vector entries
 * @tparam BLOCK_SIZE Number of offsets staged per chunk / launch bound
 *
 * @param num_rows         Number of rows processed
 * @param num_cols         Exclusive upper bound for a valid column index
 * @param num_diagonals    Number of entries in diagonal_offsets
 * @param pitch            Distance in values between consecutive diagonals
 * @param diagonal_offsets Offset of each diagonal (column = row + offset)
 * @param values           Matrix entries, read at row + pitch * diagonal
 * @param x                Input vector, read at the column index
 * @param y                Output vector, accumulated into at the row index
 */
template <typename SizeType, typename IndexType, typename ValueType,
          size_t BLOCK_SIZE>
__launch_bounds__(BLOCK_SIZE, 1) __global__
    void spmv_dia_kernel(const SizeType num_rows, const SizeType num_cols,
                         const SizeType num_diagonals, const SizeType pitch,
                         const IndexType* diagonal_offsets,
                         const ValueType* values, const ValueType* x,
                         ValueType* y) {
  // Per-block staging area for the offsets of the current chunk of diagonals.
  __shared__ IndexType cached_offsets[BLOCK_SIZE];

  const SizeType first_row  = BLOCK_SIZE * blockIdx.x + threadIdx.x;
  const SizeType row_stride = BLOCK_SIZE * gridDim.x;

  SizeType chunk_begin = 0;
  while (chunk_begin < num_diagonals) {
    // The last chunk may hold fewer than BLOCK_SIZE diagonals.
    const SizeType diags_in_chunk =
        Morpheus::Impl::min(SizeType(BLOCK_SIZE), num_diagonals - chunk_begin);

    // One offset per thread; threads beyond the chunk have nothing to load.
    if (threadIdx.x < diags_in_chunk) {
      cached_offsets[threadIdx.x] = diagonal_offsets[chunk_begin + threadIdx.x];
    }

    // Staged offsets must be complete before any thread reads them.
    __syncthreads();

    for (SizeType r = first_row; r < num_rows; r += row_stride) {
      ValueType partial = ValueType(0);

      // Position in values of this row's entry on the chunk's first diagonal;
      // stepping by pitch moves to the same row on the next diagonal.
      SizeType value_pos = r + pitch * chunk_begin;

      for (SizeType d = 0; d < diags_in_chunk; ++d, value_pos += pitch) {
        const IndexType c = r + cached_offsets[d];

        // Skip diagonal entries whose column lies outside [0, num_cols).
        const bool in_range = (c >= 0) && (c < (IndexType)num_cols);
        if (!in_range) continue;

        partial += values[value_pos] * x[c];
      }

      y[r] += partial;
    }

    // Nobody may overwrite the staging area for the next chunk while another
    // thread is still reading the current one.
    __syncthreads();

    chunk_begin += BLOCK_SIZE;
  }
}
85
86} // namespace Kernels
87
88} // namespace Impl
89} // namespace Morpheus
90
91#endif
92
93#endif // MORPHEUS_DIA_KERNELS_MULTIPLY_IMPL_HPP
Generic Morpheus interfaces.
Definition: dummy.cpp:24