summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda/mmvq.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda/mmvq.cu')
-rw-r--r--ggml/src/ggml-cuda/mmvq.cu19
1 files changed, 18 insertions, 1 deletions
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index f87ebb96..bc26cce4 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -147,10 +147,27 @@ static __global__ void mul_mat_vec_q(
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst,
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) {
int i2 = blockIdx.y;
+ char * cdst = (char *)dst + i2*nb2;
int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
+ if (i02 < 0) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+ constexpr int rows_per_cuda_block = 1;
+#else
+ constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
+ const int row0 = rows_per_cuda_block*blockIdx.x;
+ if (threadIdx.y == 0) {
+ dst = (float *)cdst;
+ for (int j = 0; j < ncols_y; ++j) {
+ if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
+ dst[j*nrows_dst + row0 + threadIdx.x] = 0;
+ }
+ }
+ }
+ return;
+ }
const char * cx = (const char *)vx + i02*nb02;
const char * cy = (const char *)vy + i2*nb12;
- char * cdst = (char *)dst + i2*nb2;
mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
}