diff options
Diffstat (limited to 'ggml/src/ggml-cuda/mmvq.cu')
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index f87ebb96..bc26cce4 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -147,10 +147,27 @@ static __global__ void mul_mat_vec_q( const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) { int i2 = blockIdx.y; + char * cdst = (char *)dst + i2*nb2; int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; + if (i02 < 0) { +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) + constexpr int rows_per_cuda_block = 1; +#else + constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) + const int row0 = rows_per_cuda_block*blockIdx.x; + if (threadIdx.y == 0) { + dst = (float *)cdst; + for (int j = 0; j < ncols_y; ++j) { + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { + dst[j*nrows_dst + row0 + threadIdx.x] = 0; + } + } + } + return; + } const char * cx = (const char *)vx + i02*nb02; const char * cy = (const char *)vy + i2*nb12; - char * cdst = (char *)dst + i2*nb2; mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); } |