7 #ifndef GHOST_TSMTTSM_CU_KERNEL_H
8 #define GHOST_TSMTTSM_CU_KERNEL_H
12 #include "ghost/config.h"
// Module-level workspace for the two-phase TSMTTSM reduction: the block-product
// kernel writes per-block partials here and deviceReduce consumes them.
// NOTE(review): file-scope mutable state — concurrent calls into this header's
// entry point would share/race on this buffer; confirm single-threaded usage.
// (The leading "19"/"20" tokens are original-file line numbers fused into the
// text by the extraction tool, not code.)
19 void *d_temp_storage = NULL;
20 size_t temp_storage_bytes = 0;
// Final reduction kernel: folds the per-thread-block partial results produced
// by genv7_blockProductKernel into the M x N output matrix, applying a
// GEMM-style update  result = accu(scale(result, beta), scale2(sum, alpha)),
// i.e. C = beta*C + alpha*(partial sums). One thread per output element.
// NOTE(review): this extracted view is missing several original lines
// (the opening brace, the declarations of `m`, `n` and the accumulator `sum`,
// and closing braces) — it will not compile as shown; compare against the
// original tsmttsm_cu_kernel.h before editing.
// NOTE(review): `lda`/`ldb` are not referenced in the visible body; only
// `ldc` is used to index `result` — presumably kept for a uniform launch
// signature; confirm.
21 template<
typename oT,
int M,
int N>
22 __global__
void deviceReduce(
23 oT *blockResults, oT *result, oT alpha, oT beta,
int blockCount,
int lda,
int ldb,
int ldc)
// Flat global thread index; each thread owns one (m, n) entry.
25 int tidx = blockDim.x * blockIdx.x + threadIdx.x;
// Bounds guard: the launch rounds the grid up, so threads past M*N bail out.
26 if (tidx >= M *
N)
return;
// Sum the blockCount partial results for this entry. Partials are laid out
// as blockResults[block][n][m] with block stride N*M.
32 for (
int i = 0; i < blockCount; i++) {
33 sum =
accu(sum, blockResults[i * N * M + n * M + m]);
// Write back: scale the existing C entry by beta and add the alpha-scaled sum
// (accu/scale/scale2 are the complex-aware helpers from cu_complex.h).
36 result[n * ldc + m] =
accu(
scale(result[n * ldc + m], beta),
scale2(sum, alpha));
// Conditionally conjugates the A-side operand: conj(v) is applied only when
// conjugation is requested (conjv) AND the kernel is running in the
// non-transposed configuration (!TRANSPOSE). Compile-time template booleans,
// so the branch folds away entirely at compilation.
// NOTE(review): the `return v;` and surrounding braces are missing from this
// extracted view.
39 template<
typename T,
bool conjv,
bool TRANSPOSE>
40 __device__ T condConj1(T v)
42 if (conjv && !TRANSPOSE) v =
conj(v);
// Counterpart of condConj1 for the B-side operand: conj(v) is applied only
// when conjugation is requested (conjv) AND the kernel runs in the TRANSPOSE
// configuration. Together the two helpers ensure exactly one operand is
// conjugated regardless of which matrix was swapped into the A slot.
// NOTE(review): the `return v;` and surrounding braces are missing from this
// extracted view.
45 template<
typename T,
bool conjv,
bool TRANSPOSE>
46 __device__ T condConj2(T v)
48 if (conjv && TRANSPOSE) v =
conj(v);
// Rounds v to a power of two — used below as the starting stride
// (`redSteps`) of the halving shared-memory tree reduction, which requires a
// power-of-two step count. NOTE(review): the function body (original lines
// 54-64) is missing from this extracted view; whether it rounds up or down
// cannot be confirmed from here — verify against the original source.
53 __device__
int roundPoT(
int v)
// Phase-1 kernel of the tall-skinny matrix product: each thread block computes
// a partial M x N product over a slice of the K dimension and writes it to
// `out` at blockIdx.x * M * N, to be combined later by deviceReduce.
// Thread layout: BLOCKSIZE threads are viewed as rowsPerBlock rows of M lanes
// (m = column index into A, localRow = row within the block's K-slice).
// Template parameters: T input / oT accumulator type, conjv = conjugate flag,
// M/N output dims, BLOCKSIZE threads per block (must be a multiple of M —
// the host side rounds it; NOTE(review): not asserted here), TRANSPOSE and
// SELF select output layout / the A==B case.
// NOTE(review): many original lines (68, 77, 80-89, 93-97, 101-110, 115-120,
// 126-155 in part) are missing from this extracted view — including the
// declarations of threadSum/avNext/bvNext, the __syncthreads() barriers, and
// closing braces. Do not treat this text as compilable; consult the original.
65 template<
typename T,
typename oT,
int conjv,
int M,
int N,
int BLOCKSIZE,
bool TRANSPOSE,
bool SELF>
66 __global__
void __launch_bounds__(BLOCKSIZE) genv7_blockProductKernel(
67 const T *A, const T *B, oT *out, const
int K, const
int lda, const
int ldb, const
int ldc)
// Partition the block: M lanes per row, rowsPerBlock rows processed at once.
69 const int rowsPerBlock = BLOCKSIZE / M;
70 int m = threadIdx.x % M;
71 int localRow = threadIdx.x / M;
72 int bOffset = localRow * ldb + m;
73 int aOffset = localRow * lda + m;
// Clamp offsets so out-of-role threads still issue safe (discarded) loads
// instead of reading out of bounds — avoids divergent guards in the hot loop.
74 if (m >= N) bOffset = localRow * ldb + 0;
75 if (bOffset >= rowsPerBlock * ldb) bOffset = 0;
76 if (aOffset >= rowsPerBlock * lda) aOffset = 0;
// Shared scratch, reused for two purposes: staging B rows as T (via rowCache)
// and holding oT partials during the tree reduction. Sized for the larger of
// the two element types (factor 2 when sizeof(T) > sizeof(oT)).
78 __shared__ oT blockStorage[rowsPerBlock * M * (
sizeof(T) >
sizeof(oT) ? 2 : 1)];
79 T *rowCache =
reinterpret_cast<T *
>(blockStorage);
// Shared memory is uninitialized; zero this thread's slot before use.
81 zero(blockStorage[threadIdx.x]);
// NOTE(review): loop body at original lines 86-89 is missing here —
// presumably zero-initialization of threadSum[n]; confirm.
85 for (
int n = 0; n <
N; n++) {
// Software-pipelined main loop over K: prefetch the next A/B values with
// __ldg (read-only cache) while computing on the current ones. Each block
// strides by gridDim.x * rowsPerBlock rows (grid-stride over K-slices).
90 int idx = blockIdx.x * rowsPerBlock;
91 T avNow = __ldg(A + idx * lda + aOffset);
92 T bvNow = __ldg(B + idx * ldb + bOffset);
98 for (; idx < K - rowsPerBlock; idx += gridDim.x * rowsPerBlock) {
// Clamp the prefetch index so the last iteration's prefetch stays in bounds.
99 int idxNext = min(K - rowsPerBlock, idx + gridDim.x * rowsPerBlock);
100 avNext = __ldg(A + idxNext * lda + aOffset);
103 bvNext = __ldg(B + idxNext * ldb + bOffset);
// Stage this row of B in shared memory so all M lanes of a row can read all
// N of its entries. NOTE(review): the __syncthreads() between this write and
// the reads below is in the omitted lines — confirm it exists in the original.
108 rowCache[threadIdx.x] = bvNow;
// Base address of this thread's row in the cache (threadIdx.x - m == localRow*M).
111 int localAddress = threadIdx.x - m;
// Rank-1 update: threadSum[n] += condConj1(av) * condConj2(B_row[n]), with
// conjugation applied to exactly one side depending on conjv/TRANSPOSE.
112 for (
int n = 0; n <
N; n++) {
113 threadSum[n] =
axpy(threadSum[n], condConj1<oT, conjv, TRANSPOSE>((oT)avNow),
114 condConj2<oT, conjv, TRANSPOSE>((oT)rowCache[localAddress + n]));
// Remainder loop: the last (< rowsPerBlock) rows of K, read straight from
// global memory without the shared-memory staging or prefetch.
121 for (idx = idx + localRow; idx < K; idx += gridDim.x * rowsPerBlock) {
122 T av = A[idx * lda + m];
123 for (
int n = 0; n <
N; n++) {
124 threadSum[n] =
axpy(threadSum[n], condConj1<oT, conjv, TRANSPOSE>((oT)av),
125 condConj2<oT, conjv, TRANSPOSE>((oT)B[idx * ldb + n]));
// Cross-row tree reduction in shared memory, one output column n at a time.
// redSteps is the power-of-two starting stride (see roundPoT).
129 const int redSteps = roundPoT(rowsPerBlock);
132 for (
int n = 0; n <
N; n++) {
// Publish this thread's partial for column n.
// NOTE(review): the barriers between these shared-memory writes and the
// reads in the halving loop are in the omitted lines (135-136, 141-144).
134 blockStorage[threadIdx.x] = threadSum[n];
// Halve the stride each step; the second condition guards the non-power-of-
// two tail (rows >= rowsPerBlock - s have no partner to add).
137 for (
unsigned int s = redSteps; s > 0; s /= 2) {
138 if (localRow < s && localRow < rowsPerBlock - s) {
139 blockStorage[localRow * M + m] =
140 accu(blockStorage[localRow * M + m], blockStorage[(localRow + s) * M + m]);
// Row 0 (first M threads) writes the block's partial result. Two layouts are
// emitted (m-major vs n-major); the selecting condition (original lines
// 146/148) is omitted here — presumably TRANSPOSE-dependent; confirm.
145 if (threadIdx.x < M) {
147 out[blockIdx.x * M * N + m * N + n] = blockStorage[m];
149 out[blockIdx.x * N * M + n * M + m] = blockStorage[m];
// Host-side entry point (per the Doxygen trailer: ghost_tsmttsm_cu_rm(C, A,
// B, alpha, beta, K, ldc, lda, ldb)) computing C = beta*C + alpha*A^T*B for
// row-major tall-skinny A (K x M) and B (K x N). Strategy: pick a block size,
// query occupancy, size the grid, run genv7_blockProductKernel to produce
// per-block partials in d_temp_storage, then deviceReduce to fold them into C.
// NOTE(review): the function signature and many interior lines (157-162, 164,
// 166-171, 175-176, 181-183, 186-189, 191-192, 194-197, 201, 206, 209-212,
// 215+) are missing from this extracted view — including the temp-buffer
// allocation, error handling, and branch structure. Not compilable as shown.
156 template<
typename T,
typename oT,
int M,
int N,
int conjv>
// Target threads per block before rounding to a multiple of M (or N).
163 const int targetBlockSize = 256;
165 cudaGetDevice(&deviceUsed);
// TRANSPOSE path: round the block size down to a multiple of N (A/B and M/N
// are swapped at launch below), and query how many such blocks fit per SM.
172 int const blockSize = (targetBlockSize /
N) * N;
173 CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks,
174 genv7_blockProductKernel<T, oT, conjv, M, N, blockSize, true, false>, blockSize, 0),
// Non-transpose path: block size is a multiple of M instead.
177 int const blockSize = (targetBlockSize / M) * M;
// SELF specialization when both operands are the same square matrix (A == B).
178 if (M == N && A == B) {
179 CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks,
180 genv7_blockProductKernel<T, oT, conjv, M, N, blockSize, false, true>,
184 CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks,
185 genv7_blockProductKernel<T, oT, conjv, M, N, blockSize, false, false>,
// Grid size: capped by occupancy (SM count * blocks per SM) and by a
// work-based heuristic so tiny K doesn't spawn mostly-idle blocks.
190 int blockCount = min( prop.multiProcessorCount * numBlocks, K*N / 10 / targetBlockSize + 1);
// One M x N partial result per block; allocation of d_temp_storage to this
// size is in the omitted lines (presumably via ghost_cu_temp_buffer_malloc,
// per the Doxygen trailer — confirm).
193 size_t required_temp_storage_bytes = M * N * blockCount *
sizeof(oT);
// TRANSPOSE launch: note A/B, M/N, lda/ldb are all swapped so the same kernel
// covers both orientations (conjugation side handled by condConj1/condConj2).
198 int const blockSize = (targetBlockSize /
N) * N;
199 genv7_blockProductKernel<T, oT, conjv, N, M, blockSize, true, false>
200 <<<blockCount, blockSize>>>(B, A, (oT *)d_temp_storage, K, ldb, lda, ldc);
202 int const blockSize = (targetBlockSize / M) * M;
203 if (M == N && A == B) {
204 genv7_blockProductKernel<T, oT, conjv, M, N, blockSize, false, true>
205 <<<blockCount, blockSize>>>(A, B, (oT *)d_temp_storage, K, lda, ldb, ldc);
207 genv7_blockProductKernel<T, oT, conjv, M, N, blockSize, false, false>
208 <<<blockCount, blockSize>>>(A, B, (oT *)d_temp_storage, K, lda, ldb, ldc);
// Phase 2: one thread per output element folds the blockCount partials into C
// with the alpha/beta update; grid rounded up to cover M*N.
// NOTE(review): no cudaGetLastError()/sync visible after these launches in
// the extracted lines — confirm error checking exists in the original.
213 deviceReduce<oT, M, N>
214 <<<(M *
N) / 256 + 1, 256>>>((oT *)d_temp_storage, C, alpha, beta, blockCount, lda, ldb, ldc);
Header file for type definitions.
__device__ T conj(T x)
Definition: cu_complex.h:226
#define CUDA_CALL(call, __err)
Definition: error.h:238
No error occurred.
Definition: error.h:27
int32_t ghost_lidx
Definition: types.h:503
__device__ T scale(T y, T a)
Definition: cu_complex.h:275
__device__ T1 scale2(T1 y, T2 a)
Definition: cu_complex.h:293
ghost_error
Error return type.
Definition: error.h:23
__device__ __host__ void zero(T &val)
Definition: cu_complex.h:12
#define N
Definition: bench.c:20
__device__ t accu(t val, t val2)
Definition: cu_complex.h:103
Inline template functions for CUDA complex number handling.
ghost_error ghost_cu_temp_buffer_free(void *mem)
Frees memory allocated with ghost_cu_temp_buffer_malloc. Threadsafe.
Definition: cu_temp_buffer_malloc.cpp:65
__device__ T axpy(T val, T val2, T2 val3)
Definition: cu_complex.h:122
static ghost_error ghost_tsmttsm_cu_rm(oT *const __restrict__ C, const T *const __restrict__ A, const T *const __restrict__ B, const oT alpha, const oT beta, ghost_lidx K, ghost_lidx ldc, ghost_lidx lda, ghost_lidx ldb)
Definition: tsmttsm_cu_kernel.h:157
ghost_error ghost_cu_deviceprop_get(ghost_cu_deviceprop *prop)
Get the CUDA device properties.
Definition: cu_util.c:517
ghost_error ghost_cu_temp_buffer_malloc(void **mem, size_t bytesize)
Useful for allocating small temporary buffers. Keeps a list of previously allocated and freed buffers...
Definition: cu_temp_buffer_malloc.cpp:27