diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-03-10 16:16:51 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-03-10 16:16:51 +0200 |
commit | 699c9cb7f63dd8431bce91b86e10efb41255f6c1 (patch) | |
tree | 6000fd823e443f80f90ec490b1bbdf6461902924 | |
parent | b096a5de7a9bdf516bb20729d5d0a3b2a12cba2f (diff) |
Faster MoE token generation on CUDA (#248)
* This gives us ~20% TG speedup for DeepSeek on CUDA
* Slightly better
* Also do it for plain (not fused) mul_mat_id
* Guard against numerical precision issues for MLA on CUDA
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-cuda.cu | 310 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cu | 88 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/iqk_mmvq.cuh | 40 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 248 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cuh | 6 | ||||
-rw-r--r-- | src/llama.cpp | 3 |
6 files changed, 487 insertions, 208 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 410c6406..f25dd725 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1765,6 +1765,93 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); } +/* +static void ggml_cuda_op_gemv_id( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src0_ids, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, + quantize_cuda_t quantize_src1) { + + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_nrows(src1) == 1); + GGML_ASSERT(src0_ids->ne[1] == 1); + GGML_ASSERT(src0_ids->ne[0] <= dst->ne[2]); + GGML_ASSERT(dst->ne[1] == 1); + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + + GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); + GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer)); + GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer)); + + ggml_backend_cuda_buffer_context * src0_ctx = (ggml_backend_cuda_buffer_context *) src0->buffer->context; + ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; + ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; + + int device_id = ctx.device; + GGML_ASSERT(src0_ctx->device == device_id); + GGML_ASSERT(src1_ctx->device == device_id); + GGML_ASSERT(dst_ctx->device == device_id); + + const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer); + GGML_ASSERT(!split); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t ne10 = src1->ne[0]; + const int64_t nrows1 = 1; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne2 = dst->ne[2]; + + const int64_t nb2 = dst->nb[2]; + + // Why? + GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1)); + + const size_t src0_rs = ggml_row_size(src0->type, ne00); + const size_t q8_1_ts = sizeof(block_q8_1); + const size_t q8_1_bs = QK8_1; + + const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); + + ggml_cuda_pool_alloc<char> src0_dd_alloc; + ggml_cuda_pool_alloc<float> src1_ddf_alloc; + ggml_cuda_pool_alloc<char> src1_ddq_alloc; + ggml_cuda_pool_alloc<float> dst_dd_alloc; + + char * src0_dd = nullptr; + float * src1_ddf = (float *)src1->data; + char * src1_ddq = nullptr; // q8_1 + float * dst_dd = (float *)dst->data; + + bool quantization_done = false; + + const bool src1_on_device = device_id == src1_ctx->device; + const bool dst_on_device = device_id == dst_ctx->device; + + ggml_cuda_set_device(device_id); + cudaStream_t stream = ctx.stream(device_id, 0); + + src0_dd = (char *) src0->data; + + if (quantize_src1) { + size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; + src1_ddq = src1_ddq_alloc.alloc(ctx.pool(device_id), src_1_ddq_size); + quantize_src1(src1_ddf, src1_ddq, ne10, 1, 1, src1_padded_col_size, src0->type, stream); + } + + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, src1, src0_ids, dst, + (const char *)src0->data, (const float *)src1->data, src1_ddq, (float *)dst->data, + 0, ne01, 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + +} +*/ + static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); @@ -2090,6 +2177,52 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * ids = dst->src[2]; + if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 && + ggml_is_quantized(src0->type) && + ggml_backend_buffer_is_cuda(src0->buffer) && + ggml_backend_buffer_is_cuda(src1->buffer) && + ggml_backend_buffer_is_cuda(dst->buffer) && + !ggml_backend_buffer_is_cuda_split(src0->buffer) && + src1->type == GGML_TYPE_F32) { + int device_id = ctx.device; + ggml_backend_cuda_buffer_context * src0_ctx = (ggml_backend_cuda_buffer_context *) src0->buffer->context; + ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; + ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; + if (src0_ctx->device == device_id && + src1_ctx->device == device_id && + dst_ctx->device == device_id) { + GGML_ASSERT(src1->ne[0] % QK8_1 == 0); + // Fast TG path + const int64_t n_ids = ids->ne[0]; + auto stream = ctx.stream(device_id, 0); + + auto local_dst = *dst; + local_dst.ne[2] = n_ids; + local_dst.ne[1] = local_dst.ne[3] = 1; + local_dst.nb[2] = local_dst.nb[1]; + + auto local_src1 = *src1; + local_src1.nb[2] = local_src1.nb[3] = 0; + + const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); + ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool()); + auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1; + local_src1.data = src1_quantized.alloc(src_1_ddq_size); + quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size, + src0->type, stream); + CUDA_CHECK(cudaGetLastError()); + + local_src1.nb[1] = src_1_ddq_size; + + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, &local_src1, ids, &local_dst, + (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data, + 0, src0->ne[1], 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + return; + } + } + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers"); @@ -2232,6 +2365,121 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor const ggml_tensor * src1 = dst->src[2]; const ggml_tensor * ids = dst->src[3]; + if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 && + ggml_is_quantized(src0_1->type) && + ggml_is_quantized(src0_2->type) && + ggml_backend_buffer_is_cuda(src0_1->buffer) && + ggml_backend_buffer_is_cuda(src0_2->buffer) && + ggml_backend_buffer_is_cuda(src1->buffer) && + ggml_backend_buffer_is_cuda(dst->buffer) && + !ggml_backend_buffer_is_cuda_split(src0_1->buffer) && + !ggml_backend_buffer_is_cuda_split(src0_2->buffer) && + src1->type == GGML_TYPE_F32) { + int device_id = ctx.device; + ggml_backend_cuda_buffer_context * src0_1_ctx = (ggml_backend_cuda_buffer_context *) src0_1->buffer->context; + ggml_backend_cuda_buffer_context * src0_2_ctx = (ggml_backend_cuda_buffer_context *) src0_2->buffer->context; + ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; + ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; + if (src0_1_ctx->device == device_id && + src0_2_ctx->device == device_id && + src1_ctx->device == device_id && + dst_ctx->device == device_id) { + // Fast TG path + const int64_t n_ids = ids->ne[0]; + auto stream = ctx.stream(device_id, 0); + ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids); + ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids); + + auto local_dst = *dst; + local_dst.ne[2] = n_ids; + local_dst.ne[1] = local_dst.ne[3] = 1; + local_dst.nb[1] = local_dst.nb[2] = local_dst.nb[3] = local_dst.ne[0]*sizeof(float); + + auto local_src1 = *src1; + local_src1.nb[2] = local_src1.nb[3] = 0; + + const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); + ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool()); + if (ggml_is_quantized(src0_1->type) || ggml_is_quantized(src0_2->type)) { + GGML_ASSERT(src1->ne[0] % QK8_1 == 0); + auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1; + local_src1.data = src1_quantized.alloc(src_1_ddq_size); + // Note: no use is currently made of the quantization type passed into quantize_row_q8_1_cuda. + // If that were to change, we would need to adjust the code to handle src0_1->type != src0_2->type + quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size, + src0_1->type, stream); + CUDA_CHECK(cudaGetLastError()); + + local_src1.nb[1] = src_1_ddq_size; + } + + local_dst.data = dst_up_contiguous.get(); + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst, + (const char *)src0_1->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_up_contiguous.get(), + 0, src0_1->ne[1], 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + local_dst.data = dst_gate_contiguous.get(); + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst, + (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(), + 0, src0_2->ne[1], 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && + ggml_backend_buffer_is_cuda(next->src[0]->buffer) && + !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) && + ((ggml_backend_cuda_buffer_context *)next->src[0]->buffer->context)->device == device_id && + ggml_backend_buffer_is_cuda(next->buffer) && + ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) { + + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids, + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + CUDA_CHECK(cudaGetLastError()); + + const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING); + GGML_ASSERT(dst->ne[0] % QK8_1 == 0); + auto dst_row_size = dst_padded_col_size*sizeof(block_q8_1)/QK8_1; + auto dst_ddq_size = n_ids*dst_row_size; + ggml_cuda_pool_alloc<char> dst_quantized(ctx.pool(), dst_ddq_size); + quantize_row_q8_1_cuda((const float *)dst_gate_contiguous.get(), (void *)dst_quantized.get(), dst->ne[0], n_ids, 1, + dst_padded_col_size, next->src[0]->type, stream); + CUDA_CHECK(cudaGetLastError()); + + std::vector<char> ids_host(ggml_nbytes(ids)); + const char * ids_dev = (const char *) ids->data; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + local_dst.ne[2] = 1; + + auto local_next = *next; + local_next.ne[2] = local_next.ne[1]; + local_next.ne[1] = local_next.ne[3] = 1; + local_next.nb[2] = local_next.nb[1]; + + local_src1 = *next->src[1]; + local_src1.ne[1] = local_src1.ne[2] = local_src1.ne[3] = 1; + local_src1.nb[1] = local_src1.nb[2] = local_src1.nb[3] = dst_row_size; + + auto local_src0 = *next->src[0]; + local_src0.ne[2] = local_src0.ne[3] = 1; + + ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next, + (const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data, + 0, next->src[0]->ne[1], 1, dst_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + return true; + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data); + CUDA_CHECK(cudaGetLastError()); + return false; + } + } + } + + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers"); @@ -2299,49 +2547,47 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor if (fuse_down) { final_dst.src[1] = &dst_row; } - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + for (int64_t id = 0; id < n_ids; id++) { + const int32_t i02 = *(const int32_t *) (ids_host.data() + id*ids->nb[0]); - if (i02 < 0 || i02 >= n_as) continue; - //GGML_ASSERT(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //GGML_ASSERT(i02 >= 0 && i02 < n_as); - const int64_t i11 = id % ne11; - const int64_t i12 = iid1; + const int64_t i11 = id % ne11; + const int64_t i12 = 0; - const int64_t i1 = id; - const int64_t i2 = i12; + const int64_t i1 = id; + const int64_t i2 = i12; - src0_1_row.data = src0_1_original + i02*nb02; - src0_2_row.data = src0_2_original + i02*nb02; - src1_row.data = src1_original + i11*nb11 + i12*nb12; - //dst_row.data = dst_original + i1*nb1 + i2*nb2; + src0_1_row.data = src0_1_original + i02*nb02; + src0_2_row.data = src0_2_original + i02*nb02; + src1_row.data = src1_original + i11*nb11 + i12*nb12; + //dst_row.data = dst_original + i1*nb1 + i2*nb2; - dst_row.data = dst_up_contiguous.get(); - ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); - CUDA_CHECK(cudaGetLastError()); + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); - dst_row.data = dst_gate_contiguous.get(); - ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); - CUDA_CHECK(cudaGetLastError()); + dst_row.data = dst_gate_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); - if (fuse_down) { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); - CUDA_CHECK(cudaGetLastError()); + if (fuse_down) { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + CUDA_CHECK(cudaGetLastError()); - final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; - final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2]; - ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); - CUDA_CHECK(cudaGetLastError()); + final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; + final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2]; + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); + CUDA_CHECK(cudaGetLastError()); - } else { + } else { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2)); - CUDA_CHECK(cudaGetLastError()); + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2)); + CUDA_CHECK(cudaGetLastError()); - } } } } else { diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 2eafe463..576c387d 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -96,11 +96,13 @@ template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y> __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __global__ void iqk_mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size, - const uint64_t nb02, const uint64_t nb12, const uint64_t nb2) { + const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) { int i2 = blockIdx.y; - const char * cx = (const char *)vx + i2*nb02; + int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; + if (i02 < 0) return; + const char * cx = (const char *)vx + i02*nb02; const char * cy = (const char *)vy + i2*nb12; char * cdst = (char *)dst + i2*nb2; iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, ncols_y>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); @@ -108,9 +110,9 @@ __global__ void iqk_mul_mat_vec_q( template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda> void iqk_mul_mat_vec_q_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); //GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); @@ -152,28 +154,28 @@ void iqk_mul_mat_vec_q_cuda( switch (ncols_y) { case 1: - iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 2: - iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 3: - iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 4: - iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 5: - iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 6: - iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 7: - iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 8: - iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; default: GGML_ASSERT(false); @@ -742,79 +744,79 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1( } // namespace void mul_mat_vec_iq2_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K, VDR_IQ2_K_Q8_1_MMVQ, vec_dot_iq2_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K, VDR_IQ2_K_Q8_1_MMVQ, vec_dot_iq2_k_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq3_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_K, VDR_IQ3_K_Q8_1_MMVQ, vec_dot_iq3_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_K, VDR_IQ3_K_Q8_1_MMVQ, vec_dot_iq3_k_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq4_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_iq4_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_iq4_k_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq4_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq4_kss_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KSS, VDR_IQ4_KSS_Q8_1_MMVQ, vec_dot_iq4_kss_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KSS, VDR_IQ4_KSS_Q8_1_MMVQ, vec_dot_iq4_kss_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq2_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KS, VDR_IQ2_KS_Q8_1_MMVQ, vec_dot_iq2_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KS, VDR_IQ2_KS_Q8_1_MMVQ, vec_dot_iq2_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq5_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq6_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq1_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN, 1, vec_dot_iq1_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN, 1, vec_dot_iq1_bn_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq2_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN, 1, vec_dot_iq2_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN, 1, vec_dot_iq2_bn_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index a128bc53..15df20f5 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -1,51 +1,51 @@ #include "common.cuh" void mul_mat_vec_iq2_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq3_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq4_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq5_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq6_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq4_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq4_kss_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq2_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq1_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq2_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index e17e77a3..b9e9c216 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -134,11 +134,12 @@ template <ggml_type type, int ncols_y> __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, - const uint64_t nb02, const uint64_t nb12, const uint64_t nb2) { + const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) { int i2 = blockIdx.y; - const char * cx = (const char *)vx + i2*nb02; + int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; + const char * cx = (const char *)vx + i02*nb02; const char * cy = (const char *)vy + i2*nb12; char * cdst = (char *)dst + i2*nb2; mul_mat_vec_q<type, ncols_y>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); @@ -146,9 +147,9 @@ static __global__ void mul_mat_vec_q( template <ggml_type type> static void mul_mat_vec_q_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); @@ -188,28 +189,28 @@ static void mul_mat_vec_q_cuda( switch (ncols_y) { case 1: - mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 2: - mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 3: - mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 4: - mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 5: - mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 6: - mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 7: - mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 8: - mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; default: GGML_ABORT("fatal error"); @@ -218,169 +219,169 @@ static void mul_mat_vec_q_cuda( } static void mul_mat_vec_q4_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q4_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q5_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q5_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q6_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_Q6_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_Q6_0>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q8_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q2_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q3_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q4_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q5_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q6_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq2_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq2_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq2_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq3_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq1_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq1_m_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq4_nl_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq4_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq3_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type, const int64_t ne00, const int64_t ne0, const int64_t ne2, - const int64_t nb02, const int64_t nb12, const int64_t nb2, - const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i, + const int64_t nb02, const int64_t nb12, const int64_t nb2, const int64_t ids_nb0, + const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i, const char * ids_data, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream) { @@ -391,98 +392,97 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size); switch (type) { case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q6_0: - mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_S: - mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ1_BN: - mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_BN: - mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_K: - mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ3_K: - mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ4_K: - mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ4_KS: - mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ4_KSS: - mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_KS: - mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ5_K: - mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ6_K: - mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; default: GGML_ABORT("fatal error"); @@ -505,14 +505,12 @@ void ggml_cuda_op_mul_mat_vec_q_3D( const int64_t ne0 = dst->ne[0]; - int id = ggml_cuda_get_device(); - const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size); ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, ne00, ne0, dst->ne[2], - src0->nb[2], src1_row_size, dst->nb[2], - src0_dd_i, src1_ddq_i, dst_dd_i, + src0->nb[2], src1_row_size, dst->nb[2], 0, + src0_dd_i, src1_ddq_i, dst_dd_i, nullptr, row_low, row_high, src1_ncols, src1_padded_row_size, stream); @@ -531,11 +529,35 @@ void ggml_cuda_op_mul_mat_vec_q( const int64_t ne0 = dst->ne[0]; - int id = ggml_cuda_get_device(); + ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, + ne00, ne0, 1, 0, 0, 0, 0, + src0_dd_i, src1_ddq_i, dst_dd_i, nullptr, + row_low, row_high, src1_ncols, + src1_padded_row_size, stream); + + GGML_UNUSED(src1_ddf_i); +} + +void ggml_cuda_op_mul_mat_vec_q_id( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1); + GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1); + GGML_ASSERT(ids->ne[0] == dst->ne[2]); + + const int64_t ne0 = dst->ne[0]; ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, - ne00, ne0, 1, 0, 0, 0, - src0_dd_i, src1_ddq_i, dst_dd_i, + ne00, ne0, dst->ne[2], + src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0], + src0_dd_i, src1_ddq_i, dst_dd_i, (const char *)ids->data, row_low, row_high, src1_ncols, src1_padded_row_size, stream); diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index f86aefe2..c0699b04 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -11,3 +11,9 @@ void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); + +void ggml_cuda_op_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); diff --git a/src/llama.cpp b/src/llama.cpp index 7d665072..bad8d33d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13734,6 +13734,9 @@ struct llm_build_context { } ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); + if (kv_cache->ne[1] < 256) { + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } cb(kq, "kq", il); if (!pp_opt) { |