Morpheus 1.0.0
Dynamic matrix type and algorithms for sparse matrices
Coo/Kernels/Morpheus_Multiply_Impl.hpp
#ifndef MORPHEUS_COO_KERNELS_MULTIPLY_IMPL_HPP
#define MORPHEUS_COO_KERNELS_MULTIPLY_IMPL_HPP

#include <Morpheus_Macros.hpp>
#if defined(MORPHEUS_ENABLE_CUDA) || defined(MORPHEUS_ENABLE_HIP)

namespace Morpheus {
namespace Impl {

namespace Kernels {

// forward decl
template <typename IndexType, typename ValueType>
__device__ void segreduce_block(const IndexType* idx, ValueType* val);

// COO format SpMV kernel that uses only one thread.
// This is incredibly slow, so it is only useful for testing purposes,
// *extremely* small matrices, or a few elements at the end of a
// larger matrix.
template <typename SizeType, typename IndexType, typename ValueType>
__global__ void spmv_coo_serial_kernel(const SizeType nnnz, const IndexType* I,
                                       const IndexType* J, const ValueType* V,
                                       const ValueType* x, ValueType* y) {
  for (SizeType n = 0; n < nnnz; n++) {
    y[I[n]] += V[n] * x[J[n]];
  }
}
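
// Illustrative usage (not part of the original source): since the kernel
// above runs on a single thread, the only sensible launch configuration is
// one block of one thread; `nnnz`, `I`, `J`, `V`, `x`, and `y` are assumed
// to be device-resident:
//
//   spmv_coo_serial_kernel<size_t, int, double>
//       <<<1, 1>>>(nnnz, I, J, V, x, y);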

// spmv_coo_flat_kernel
//
// In this kernel each warp processes an interval of the nonzero values.
// For example, if the matrix contains 128 nonzero values and there are
// two warps and interval_size is 64, then the first warp (warp_id == 0)
// will process the first set of 64 values (interval [0, 64)) and the
// second warp will process the second set of 64 values
// (interval [64, 128)). Note that the number of nonzeros is not always
// a multiple of 32 (the warp size) or 32 * the number of active warps,
// so the last active warp will not always process a "full" interval of
// interval_size.
//
// The first thread in each warp (thread_lane == 0) has a special role:
// it is responsible for keeping track of the "carry" values from one
// iteration to the next. The carry values consist of the row index and
// partial sum from the previous batch of 32 elements. In the example
// mentioned before with two warps and 128 nonzero elements, the first
// warp iterates twice and looks at the carry of the first iteration to
// decide whether to include this partial sum in the current batch.
// Specifically, if a row extends over a 32-element boundary, then the
// partial sum is carried over into the new 32-element batch. If,
// on the other hand, the _last_ row index of the previous batch (the carry)
// differs from the _first_ row index of the current batch (the row
// read by the thread with thread_lane == 0), then the partial sum
// is written out to memory.
//
// Each warp iterates over its interval, processing 32 elements at a time.
// For each batch of 32 elements, the warp does the following:
// 1) Fetch the row index, column index, and value for a matrix entry. These
//    values are loaded from I[n], J[n], and V[n] respectively.
//    The row entry is stored in the shared memory array idx.
// 2) Fetch the corresponding entry from the input vector. Specifically, for a
//    nonzero entry (i,j) in the matrix, the thread must load the value x[j]
//    from memory. We use the function fetch_x to control whether the texture
//    cache is used to load the value (UseCache == True) or whether a normal
//    global load is used (UseCache == False).
// 3) The matrix value A(i,j) (which was stored in V[n]) is multiplied by the
//    value x[j] and stored in the shared memory array val.
// 4) The first thread in the warp (thread_lane == 0) considers the "carry"
//    row index and either includes the carried sum in its own sum, or it
//    updates the output vector (y) with the carried sum.
// 5) With row indices in the shared array idx and sums in the shared array
//    val, the warp conducts a segmented scan. The segmented scan operation
//    looks at the row entries for each thread (stored in idx) to see whether
//    two values belong to the same segment (segments correspond to matrix
//    rows). Consider the following example, which consists of 3 segments
//    (note: this example uses a warp size of 16 instead of the usual 32):
//
//           0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15   # thread_lane
//    idx [  0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2]  # row indices
//    val [  4, 6, 5, 0, 8, 3, 2, 8, 3, 1, 4, 9, 2, 5, 2, 4]  # A(i,j) * x(j)
//
//    After the segmented scan the result will be
//
//    val [ 4,10,15,15,23,26, 2,10,13,14, 4,13,15,20,22,26]   # A(i,j) * x(j)
//
// 6) After the warp computes the segmented scan operation,
//    each thread except for the last (thread_lane == 31) looks
//    at the row index of the next thread (threadIdx.x + 1) to
//    see if the segment ends here, or continues into the
//    next thread. The thread at the end of the segment writes
//    the sum into the output vector (y) at the corresponding row
//    index.
// 7) The last thread in each warp (thread_lane == 31) writes
//    its row index and partial sum into the designated spot in the
//    carry_idx and carry_val arrays. The carry arrays are indexed
//    by warp_lane which is a number in [0, BLOCK_SIZE / 32).
//
// These steps are repeated until the warp reaches the end of its interval.
// The carry values at the end of each interval are written to arrays
// temp_rows and temp_vals, which are processed by a second kernel (see the
// dispatch sketch after spmv_coo_reduce_update_kernel below).
//
template <typename SizeType, typename IndexType, typename ValueType,
          size_t BLOCK_SIZE>
__launch_bounds__(BLOCK_SIZE, 1) __global__
    void spmv_coo_flat_kernel(const SizeType nnnz, const SizeType interval_size,
                              const IndexType* I, const IndexType* J,
                              const ValueType* V, const ValueType* x,
                              ValueType* y, IndexType* temp_rows,
                              ValueType* temp_vals) {
  const SizeType MID_LANE  = WARP_SIZE / 2;
  const SizeType LAST_LANE = WARP_SIZE - 1;

  __shared__ volatile IndexType
      rows[(WARP_SIZE + MID_LANE) * (BLOCK_SIZE / WARP_SIZE)];
  __shared__ volatile ValueType vals[BLOCK_SIZE];

  const SizeType thread_id =
      BLOCK_SIZE * blockIdx.x + threadIdx.x;  // global thread index
  const SizeType thread_lane =
      threadIdx.x & (WARP_SIZE - 1);  // thread index within the warp
  const SizeType warp_id = thread_id / WARP_SIZE;  // global warp index

  const SizeType interval_begin =
      warp_id * interval_size;  // warp's offset into I,J,V
  const SizeType interval_end = Morpheus::Impl::min(
      interval_begin + interval_size, nnnz);  // end of warp's work

  const SizeType idx = MID_LANE * (threadIdx.x / WARP_SIZE + 1) +
                       threadIdx.x;  // thread's index into padded rows array

  rows[idx - MID_LANE] = -1;  // fill padding with invalid row index

  if (interval_begin >= interval_end)  // warp has no work to do
    return;

  if (thread_lane == WARP_SIZE - 1) {
    // initialize the carry in values
    rows[idx] = I[interval_begin];
    vals[threadIdx.x] = ValueType(0);
  }

  for (SizeType n = interval_begin + thread_lane; n < interval_end;
       n += WARP_SIZE) {
    IndexType row = I[n];            // row index (i)
    ValueType val = V[n] * x[J[n]];  // A(i,j) * x(j)

    if (thread_lane == 0) {
      if (row == rows[idx + LAST_LANE])
        val += ValueType(vals[threadIdx.x + LAST_LANE]);  // row continues
      else
        y[rows[idx + LAST_LANE]] +=
            ValueType(vals[threadIdx.x + LAST_LANE]);  // row terminated
    }

    rows[idx] = row;
    vals[threadIdx.x] = val;

    // warp-level segmented scan over the current batch of elements
    if (row == rows[idx - 1]) {
      vals[threadIdx.x] = val += ValueType(vals[threadIdx.x - 1]);
    }
    if (row == rows[idx - 2]) {
      vals[threadIdx.x] = val += ValueType(vals[threadIdx.x - 2]);
    }
    if (row == rows[idx - 4]) {
      vals[threadIdx.x] = val += ValueType(vals[threadIdx.x - 4]);
    }
    if (row == rows[idx - 8]) {
      vals[threadIdx.x] = val += ValueType(vals[threadIdx.x - 8]);
    }
    if (row == rows[idx - 16]) {
      vals[threadIdx.x] = val += ValueType(vals[threadIdx.x - 16]);
    }

#if defined(MORPHEUS_ENABLE_HIP)
    // HIP wavefronts are 64 lanes wide, so one more scan step is needed
    if (row == rows[idx - 32]) {
      vals[threadIdx.x] = val += ValueType(vals[threadIdx.x - 32]);
    }
#endif  // MORPHEUS_ENABLE_HIP

    if (thread_lane < LAST_LANE && row != rows[idx + 1])
      y[row] += ValueType(vals[threadIdx.x]);  // row terminated
  }

  if (thread_lane == LAST_LANE) {
    // write the carry out values
    temp_rows[warp_id] = IndexType(rows[idx]);
    temp_vals[warp_id] = ValueType(vals[threadIdx.x]);
  }
}
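
// Reference sketch (not part of the original source): the segmented scan of
// steps 5-6 above, written as plain sequential C++ for clarity. Applied to
// the 16-lane example in the comment block, it reproduces the listed output
// val [ 4,10,15,15,23,26, 2,10,13,14, 4,13,15,20,22,26], and each segment's
// last element then holds the complete per-row sum.
//
//   #include <vector>
//
//   // Inclusive segmented scan: accumulate while the row index stays equal.
//   void segmented_scan(const std::vector<int>& idx,
//                       std::vector<double>& val) {
//     for (std::size_t i = 1; i < val.size(); i++)
//       if (idx[i] == idx[i - 1]) val[i] += val[i - 1];
//   }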

// The second level of the segmented reduction operation
template <typename SizeType, typename IndexType, typename ValueType,
          size_t BLOCK_SIZE>
__launch_bounds__(BLOCK_SIZE, 1) __global__
    void spmv_coo_reduce_update_kernel(const SizeType num_warps,
                                       const IndexType* temp_rows,
                                       const ValueType* temp_vals,
                                       ValueType* y) {
  __shared__ IndexType rows[BLOCK_SIZE + 1];
  __shared__ ValueType vals[BLOCK_SIZE + 1];

  const SizeType end = num_warps - (num_warps & (BLOCK_SIZE - 1));

  if (threadIdx.x == 0) {
    rows[BLOCK_SIZE] = (IndexType)-1;
    vals[BLOCK_SIZE] = (ValueType)0;
  }

  __syncthreads();

  SizeType i = threadIdx.x;

  while (i < end) {
    // do full blocks
    rows[threadIdx.x] = temp_rows[i];
    vals[threadIdx.x] = temp_vals[i];

    __syncthreads();

    segreduce_block(rows, vals);

    if (rows[threadIdx.x] != rows[threadIdx.x + 1])
      y[rows[threadIdx.x]] += vals[threadIdx.x];

    __syncthreads();

    i += BLOCK_SIZE;
  }

  if (end < num_warps) {
    if (i < num_warps) {
      rows[threadIdx.x] = temp_rows[i];
      vals[threadIdx.x] = temp_vals[i];
    } else {
      rows[threadIdx.x] = (IndexType)-1;
      vals[threadIdx.x] = (ValueType)0;
    }

    __syncthreads();

    segreduce_block(rows, vals);

    if (i < num_warps)
      if (rows[threadIdx.x] != rows[threadIdx.x + 1])
        y[rows[threadIdx.x]] += vals[threadIdx.x];
  }
}
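
// Dispatch sketch (not part of the original source): how the three kernels
// in this file fit together on the host side. The names block_size,
// num_blocks, num_warps, interval_size, and tail are illustrative
// assumptions; the actual sizing logic lives in the calling code. Here
// temp_rows and temp_vals provide one slot per active warp, and tail is
// the largest multiple of the warp size not exceeding nnnz.
//
//   // 1) Each warp reduces one interval of nonzeros, leaving one carry
//   //    (row index, partial sum) per warp in temp_rows / temp_vals.
//   spmv_coo_flat_kernel<size_t, int, double, block_size>
//       <<<num_blocks, block_size>>>(tail, interval_size, I, J, V, x, y,
//                                    temp_rows, temp_vals);
//   // 2) A single block folds the per-warp carries into y.
//   spmv_coo_reduce_update_kernel<size_t, int, double, block_size>
//       <<<1, block_size>>>(num_warps, temp_rows, temp_vals, y);
//   // 3) The few nonzeros past the last full interval are handled serially.
//   spmv_coo_serial_kernel<size_t, int, double>
//       <<<1, 1>>>(nnnz - tail, I + tail, J + tail, V + tail, x, y);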

template <typename IndexType, typename ValueType>
__device__ void segreduce_block(const IndexType* idx, ValueType* val) {
  // Block-wide inclusive segmented scan: each thread accumulates the partial
  // sums of lower-numbered threads that share its idx (row) value. The
  // stride doubles from 1 up to 256, so the scan is correct for block sizes
  // of up to 512 threads.
  for (unsigned int stride = 1; stride <= 256; stride <<= 1) {
    ValueType left = 0;
    if (threadIdx.x >= stride &&
        idx[threadIdx.x] == idx[threadIdx.x - stride]) {
      left = val[threadIdx.x - stride];
    }
    __syncthreads();  // all reads complete before any writes
    val[threadIdx.x] += left;
    __syncthreads();  // all writes complete before the next stride's reads
  }
}
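// Worked example (not part of the original source): with four threads,
// idx = {5, 5, 7, 7} and val = {1, 2, 3, 4}, segreduce_block leaves
// val = {1, 3, 3, 7}; threads 1 and 3 sit at segment ends and hold the
// complete sums for rows 5 and 7, which the caller then adds into y.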

}  // namespace Kernels
}  // namespace Impl
}  // namespace Morpheus

#endif  // MORPHEUS_ENABLE_CUDA || MORPHEUS_ENABLE_HIP

#endif  // MORPHEUS_COO_KERNELS_MULTIPLY_IMPL_HPP