diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-06-11 17:39:01 +0300 |
---|---|---|
committer | Georgi Gerganov <ggerganov@gmail.com> | 2024-06-16 20:32:49 +0300 |
commit | 19b7a836f6658e18e973af532a5cc6ad6b3a27f8 (patch) | |
tree | dce35ebb1930615fc0deae9d68c8a0c366a8104d | |
parent | b5fcf8ef5c29df53cfff60e180b4992a3b2332a6 (diff) |
cuda : fix bounds check for src0 rows in MMVQ kernel (whisper/2231)
* cuda : fix bounds check for src0 rows in MMVQ kernel
* Update ggml-cuda/mmvq.cu
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
-rw-r--r-- | ggml-cuda/mmvq.cu | 2 |
1 file changed, 1 insertion, 1 deletion
```diff
diff --git a/ggml-cuda/mmvq.cu b/ggml-cuda/mmvq.cu
index 5f056e91..e8d15716 100644
--- a/ggml-cuda/mmvq.cu
+++ b/ggml-cuda/mmvq.cu
@@ -117,7 +117,7 @@ static __global__ void mul_mat_vec_q(
             tmp[j][i] = warp_reduce_sum(tmp[j][i]);
         }
 
-        if (threadIdx.x < rows_per_cuda_block) {
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
             dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
```