#ifndef GHOST_TSMM_CU_KERNEL_H
#define GHOST_TSMM_CU_KERNEL_H

#include "ghost/config.h"
#include "ghost/types.h"
#include "ghost/error.h"
#include "ghost/cu_complex.h"
#include "ghost/cu_util.h"
#include <cublas_v2.h>
template<typename T>
bool eq(const T lhs, const T rhs)
{
    return lhs == rhs;
}

template<>
inline bool eq<cuDoubleComplex>(const cuDoubleComplex lhs, const cuDoubleComplex rhs)
{
    return lhs.x == rhs.x && lhs.y == rhs.y;
}

template<>
inline bool eq<cuFloatComplex>(const cuFloatComplex lhs, const cuFloatComplex rhs)
{
    return lhs.x == rhs.x && lhs.y == rhs.y;
}
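// Kernel for out = alpha * A * B + beta * out with compile-time dimensions M
// and N: A is K x M (row-major, leading dimension lda), B is M x N
// (column-major, leading dimension ldb) and out is K x N (row-major, leading
// dimension ldc). BETAISZERO selects a variant that never reads out.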
template<typename T, typename iT, int M, int N, int BLOCKSIZE, bool BETAISZERO>
static __global__ void __launch_bounds__(BLOCKSIZE)
    tsmm_fix_fb_kernel(const T *__restrict__ A, const iT *__restrict__ B, T *out, const int K,
        const int lda, const int ldb, const int ldc, iT alpha, iT beta)
{
    int tidx = blockIdx.x * BLOCKSIZE + threadIdx.x;
    int n = tidx % N; // output column handled by this thread

    // B is staged in shared memory; 1 << 14 = 16 KiB is the size budget.
    const bool fitsShm = (M * N * sizeof(iT) <= (1 << 14));
    __shared__ iT bCache[fitsShm ? M : 1][fitsShm ? N : 1];

    // Cooperatively load the M x N block of B into the shared-memory cache.
    for (int mn = threadIdx.x; mn < M * N; mn += BLOCKSIZE) {
        int tn = mn / M;
        int tm = mn % M;
        bCache[tm][tn] = B[tn * ldb + tm];
    }
    __syncthreads();

    // Trailing threads that do not form a complete row group bail out; this
    // only matters for the beta != 0 path, which reads out before writing.
    int row = tidx / N;
    if (tidx / N == gridDim.x * BLOCKSIZE / N && !BETAISZERO)
        return;
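    // Grid-stride main loop over pairs of rows: each thread accumulates row
    // and row + K / 2 simultaneously, which increases instruction-level
    // parallelism while reusing the shared B cache.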
    for (; row < K / 2; row += gridDim.x * BLOCKSIZE / N) {
        iT sum1, sum2;
        zero(sum1);
        zero(sum2);
        const int o1 = row * lda;
        const int o2 = (row + K / 2) * lda;
#pragma unroll(M <= 8 ? M : 1)
        for (int m = 0; m < M; m++) {
            const iT bV = bCache[m][n];
            sum1 = axpy(sum1, (iT)A[o1 + m], bV);
            sum2 = axpy(sum2, (iT)A[o2 + m], bV);
        }
        if (BETAISZERO) {
            out[row * ldc + n] = scale(alpha, sum1);
            out[(row + K / 2) * ldc + n] = scale(alpha, sum2);
        } else {
            out[row * ldc + n] = axpby(sum1, (iT)out[row * ldc + n], alpha, beta);
            out[(row + K / 2) * ldc + n] = axpby(sum2, (iT)out[(row + K / 2) * ldc + n], alpha, beta);
        }
    }
    // Remainder loop for the rows left over by the paired main loop.
    for (row += K / 2; row < K; row += gridDim.x * BLOCKSIZE / N) {
        iT sum;
        zero(sum);
#pragma unroll(M <= 8 ? M : 1)
        for (int m = 0; m < M; m++) { sum = axpy(sum, (iT)A[row * lda + m], bCache[m][n]); }
        if (BETAISZERO) {
            out[row * ldc + n] = scale(alpha, sum);
        } else {
            out[row * ldc + n] = axpby(sum, (iT)out[row * ldc + n], alpha, beta);
        }
    }
}
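/**
 * Tall-skinny matrix-matrix multiplication C = alpha * A * B + beta * C with
 * row-major A (K x M) and C (K x N) and column-major B (M x N), for fixed
 * small M and N.
 *
 * Returns false without launching a kernel if B does not fit into shared
 * memory, so that the caller can fall back to another implementation.
 */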
template<typename T, typename iT, int M, int N>
bool ghost_tsmm_cu_rm_cm(T *C, const T *A, const iT *B, const iT alpha, const iT beta,
    const ghost_lidx K, const ghost_lidx ldc, const ghost_lidx lda, const ghost_lidx ldb)
{
    // The kernel caches all of B in shared memory, so it only applies if B
    // fits into the 16 KiB budget.
    const bool fitsShm = (M * N * sizeof(iT) <= (1 << 14));
    if (!fitsShm)
        return false;

    const int threadsPerBlock = (M * N > 1024) ? 1024 : (M * N > 55 ? 512 : 256);

    int deviceUsed;
    cudaGetDevice(&deviceUsed);
    ghost_cu_deviceprop prop;
    ghost_cu_deviceprop_get(&prop);

    // Size the grid to the kernel's occupancy limit on the current device.
    ghost_error ret = GHOST_SUCCESS;
    int numBlocks;
    CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks,
                  tsmm_fix_fb_kernel<T, iT, M, N, threadsPerBlock, false>, threadsPerBlock, 0),
        ret);
    if (ret != GHOST_SUCCESS)
        return false;
    int blockCount = prop.multiProcessorCount * numBlocks;

    // Dispatch the beta == 0 specialization when possible; it never reads C.
    iT Tzero;
    zero(Tzero);
    if (eq(beta, Tzero)) {
        tsmm_fix_fb_kernel<T, iT, M, N, threadsPerBlock, true>
            <<<blockCount, threadsPerBlock>>>(A, B, C, K, lda, ldb, ldc, alpha, beta);
    } else {
        tsmm_fix_fb_kernel<T, iT, M, N, threadsPerBlock, false>
            <<<blockCount, threadsPerBlock>>>(A, B, C, K, lda, ldb, ldc, alpha, beta);
    }
    return true;
}

#endif
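// Usage sketch (hypothetical values; assumes device-resident buffers whose
// shapes match the compile-time dimensions M = 4 and N = 8):
//
//   double *A, *B, *C;            // device pointers, e.g. from cudaMalloc
//   const ghost_lidx K = 1000000; // A is K x 4, B is 4 x 8, C is K x 8
//   if (!ghost_tsmm_cu_rm_cm<double, double, 4, 8>(
//           C, A, B, 1.0, 0.0, K, 8 /*ldc*/, 4 /*lda*/, 8 /*ldb*/)) {
//       // B did not fit into shared memory; fall back to a generic path.
//   }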