summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu12
1 files changed, 9 insertions, 3 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 568c411a..b2211d85 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -5131,10 +5131,10 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;
- for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
- const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index
+ for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
- const int iby = (i + threadIdx.x / (qi/vdr)) * (qk/QK8_1); // y block index that aligns with ibx
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
@@ -10918,6 +10918,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (a->ne[3] != b->ne[3]) {
return false;
}
+ ggml_type a_type = a->type;
+ if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS) {
+ if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+ return false;
+ }
+ }
return true;
} break;
case GGML_OP_GET_ROWS: