summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-03-10 16:16:51 +0200
committerGitHub <noreply@github.com>2025-03-10 16:16:51 +0200
commit699c9cb7f63dd8431bce91b86e10efb41255f6c1 (patch)
tree6000fd823e443f80f90ec490b1bbdf6461902924
parentb096a5de7a9bdf516bb20729d5d0a3b2a12cba2f (diff)
Faster MoE token generation on CUDA (#248)
* This gives us ~20% TG speedup for DeepSeek on CUDA * Slightly better * Also do it for plain (not fused) mul_mat_id * Guard against numerical precision issues for MLA on CUDA --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/src/ggml-cuda.cu310
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cu88
-rw-r--r--ggml/src/ggml-cuda/iqk_mmvq.cuh40
-rw-r--r--ggml/src/ggml-cuda/mmvq.cu248
-rw-r--r--ggml/src/ggml-cuda/mmvq.cuh6
-rw-r--r--src/llama.cpp3
6 files changed, 487 insertions, 208 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 410c6406..f25dd725 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -1765,6 +1765,93 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
}
+/*
+static void ggml_cuda_op_gemv_id(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src0_ids, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
+ quantize_cuda_t quantize_src1) {
+
+ GGML_ASSERT(src0->ne[3] == 1);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_nrows(src1) == 1);
+ GGML_ASSERT(src0_ids->ne[1] == 1);
+ GGML_ASSERT(src0_ids->ne[0] <= dst->ne[2]);
+ GGML_ASSERT(dst->ne[1] == 1);
+ GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+
+ GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+ GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
+ GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
+
+ ggml_backend_cuda_buffer_context * src0_ctx = (ggml_backend_cuda_buffer_context *) src0->buffer->context;
+ ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
+ ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
+
+ int device_id = ctx.device;
+ GGML_ASSERT(src0_ctx->device == device_id);
+ GGML_ASSERT(src1_ctx->device == device_id);
+ GGML_ASSERT(dst_ctx->device == device_id);
+
+ const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
+ GGML_ASSERT(!split);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t nrows1 = 1;
+
+ const int64_t ne0 = dst->ne[0];
+ const int64_t ne2 = dst->ne[2];
+
+ const int64_t nb2 = dst->nb[2];
+
+ // Why?
+ GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
+
+ const size_t src0_rs = ggml_row_size(src0->type, ne00);
+ const size_t q8_1_ts = sizeof(block_q8_1);
+ const size_t q8_1_bs = QK8_1;
+
+ const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+
+ ggml_cuda_pool_alloc<char> src0_dd_alloc;
+ ggml_cuda_pool_alloc<float> src1_ddf_alloc;
+ ggml_cuda_pool_alloc<char> src1_ddq_alloc;
+ ggml_cuda_pool_alloc<float> dst_dd_alloc;
+
+ char * src0_dd = nullptr;
+ float * src1_ddf = (float *)src1->data;
+ char * src1_ddq = nullptr; // q8_1
+ float * dst_dd = (float *)dst->data;
+
+ bool quantization_done = false;
+
+ const bool src1_on_device = device_id == src1_ctx->device;
+ const bool dst_on_device = device_id == dst_ctx->device;
+
+ ggml_cuda_set_device(device_id);
+ cudaStream_t stream = ctx.stream(device_id, 0);
+
+ src0_dd = (char *) src0->data;
+
+ if (quantize_src1) {
+ size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+ src1_ddq = src1_ddq_alloc.alloc(ctx.pool(device_id), src_1_ddq_size);
+ quantize_src1(src1_ddf, src1_ddq, ne10, 1, 1, src1_padded_col_size, src0->type, stream);
+ }
+
+ ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, src1, src0_ids, dst,
+ (const char *)src0->data, (const float *)src1->data, src1_ddq, (float *)dst->data,
+ 0, ne01, 1, src1_padded_col_size, stream);
+ CUDA_CHECK(cudaGetLastError());
+
+}
+*/
+
static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
@@ -2090,6 +2177,52 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * ids = dst->src[2];
+ if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 &&
+ ggml_is_quantized(src0->type) &&
+ ggml_backend_buffer_is_cuda(src0->buffer) &&
+ ggml_backend_buffer_is_cuda(src1->buffer) &&
+ ggml_backend_buffer_is_cuda(dst->buffer) &&
+ !ggml_backend_buffer_is_cuda_split(src0->buffer) &&
+ src1->type == GGML_TYPE_F32) {
+ int device_id = ctx.device;
+ ggml_backend_cuda_buffer_context * src0_ctx = (ggml_backend_cuda_buffer_context *) src0->buffer->context;
+ ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
+ ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
+ if (src0_ctx->device == device_id &&
+ src1_ctx->device == device_id &&
+ dst_ctx->device == device_id) {
+ GGML_ASSERT(src1->ne[0] % QK8_1 == 0);
+ // Fast TG path
+ const int64_t n_ids = ids->ne[0];
+ auto stream = ctx.stream(device_id, 0);
+
+ auto local_dst = *dst;
+ local_dst.ne[2] = n_ids;
+ local_dst.ne[1] = local_dst.ne[3] = 1;
+ local_dst.nb[2] = local_dst.nb[1];
+
+ auto local_src1 = *src1;
+ local_src1.nb[2] = local_src1.nb[3] = 0;
+
+ const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
+ ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
+ auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1;
+ local_src1.data = src1_quantized.alloc(src_1_ddq_size);
+ quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size,
+ src0->type, stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ local_src1.nb[1] = src_1_ddq_size;
+
+ ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, &local_src1, ids, &local_dst,
+ (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
+ 0, src0->ne[1], 1, src1_padded_col_size, stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ return;
+ }
+ }
+
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
@@ -2232,6 +2365,121 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
const ggml_tensor * src1 = dst->src[2];
const ggml_tensor * ids = dst->src[3];
+ if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 &&
+ ggml_is_quantized(src0_1->type) &&
+ ggml_is_quantized(src0_2->type) &&
+ ggml_backend_buffer_is_cuda(src0_1->buffer) &&
+ ggml_backend_buffer_is_cuda(src0_2->buffer) &&
+ ggml_backend_buffer_is_cuda(src1->buffer) &&
+ ggml_backend_buffer_is_cuda(dst->buffer) &&
+ !ggml_backend_buffer_is_cuda_split(src0_1->buffer) &&
+ !ggml_backend_buffer_is_cuda_split(src0_2->buffer) &&
+ src1->type == GGML_TYPE_F32) {
+ int device_id = ctx.device;
+ ggml_backend_cuda_buffer_context * src0_1_ctx = (ggml_backend_cuda_buffer_context *) src0_1->buffer->context;
+ ggml_backend_cuda_buffer_context * src0_2_ctx = (ggml_backend_cuda_buffer_context *) src0_2->buffer->context;
+ ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
+ ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
+ if (src0_1_ctx->device == device_id &&
+ src0_2_ctx->device == device_id &&
+ src1_ctx->device == device_id &&
+ dst_ctx->device == device_id) {
+ // Fast TG path
+ const int64_t n_ids = ids->ne[0];
+ auto stream = ctx.stream(device_id, 0);
+ ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids);
+ ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids);
+
+ auto local_dst = *dst;
+ local_dst.ne[2] = n_ids;
+ local_dst.ne[1] = local_dst.ne[3] = 1;
+ local_dst.nb[1] = local_dst.nb[2] = local_dst.nb[3] = local_dst.ne[0]*sizeof(float);
+
+ auto local_src1 = *src1;
+ local_src1.nb[2] = local_src1.nb[3] = 0;
+
+ const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
+ ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
+ if (ggml_is_quantized(src0_1->type) || ggml_is_quantized(src0_2->type)) {
+ GGML_ASSERT(src1->ne[0] % QK8_1 == 0);
+ auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1;
+ local_src1.data = src1_quantized.alloc(src_1_ddq_size);
+ // Note: no use is currently made of the quantization type passed into quantize_row_q8_1_cuda.
+ // If that were to change, we would need to adjust the code to handle src0_1->type != src0_2->type
+ quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size,
+ src0_1->type, stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ local_src1.nb[1] = src_1_ddq_size;
+ }
+
+ local_dst.data = dst_up_contiguous.get();
+ ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst,
+ (const char *)src0_1->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_up_contiguous.get(),
+ 0, src0_1->ne[1], 1, src1_padded_col_size, stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ local_dst.data = dst_gate_contiguous.get();
+ ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst,
+ (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(),
+ 0, src0_2->ne[1], 1, src1_padded_col_size, stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
+ ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
+ !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
+ ((ggml_backend_cuda_buffer_context *)next->src[0]->buffer->context)->device == device_id &&
+ ggml_backend_buffer_is_cuda(next->buffer) &&
+ ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) {
+
+ ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids,
+ (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
+ CUDA_CHECK(cudaGetLastError());
+
+ const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING);
+ GGML_ASSERT(dst->ne[0] % QK8_1 == 0);
+ auto dst_row_size = dst_padded_col_size*sizeof(block_q8_1)/QK8_1;
+ auto dst_ddq_size = n_ids*dst_row_size;
+ ggml_cuda_pool_alloc<char> dst_quantized(ctx.pool(), dst_ddq_size);
+ quantize_row_q8_1_cuda((const float *)dst_gate_contiguous.get(), (void *)dst_quantized.get(), dst->ne[0], n_ids, 1,
+ dst_padded_col_size, next->src[0]->type, stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ std::vector<char> ids_host(ggml_nbytes(ids));
+ const char * ids_dev = (const char *) ids->data;
+ CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+ CUDA_CHECK(cudaStreamSynchronize(stream));
+
+ local_dst.ne[2] = 1;
+
+ auto local_next = *next;
+ local_next.ne[2] = local_next.ne[1];
+ local_next.ne[1] = local_next.ne[3] = 1;
+ local_next.nb[2] = local_next.nb[1];
+
+ local_src1 = *next->src[1];
+ local_src1.ne[1] = local_src1.ne[2] = local_src1.ne[3] = 1;
+ local_src1.nb[1] = local_src1.nb[2] = local_src1.nb[3] = dst_row_size;
+
+ auto local_src0 = *next->src[0];
+ local_src0.ne[2] = local_src0.ne[3] = 1;
+
+ ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next,
+ (const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data,
+ 0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ return true;
+ } else {
+ ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
+ (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
+ CUDA_CHECK(cudaGetLastError());
+ return false;
+ }
+ }
+ }
+
+
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers");
@@ -2299,49 +2547,47 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
if (fuse_down) {
final_dst.src[1] = &dst_row;
}
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
- for (int64_t id = 0; id < n_ids; id++) {
- const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+ for (int64_t id = 0; id < n_ids; id++) {
+ const int32_t i02 = *(const int32_t *) (ids_host.data() + id*ids->nb[0]);
- if (i02 < 0 || i02 >= n_as) continue;
- //GGML_ASSERT(i02 >= 0 && i02 < n_as);
+ if (i02 < 0 || i02 >= n_as) continue;
+ //GGML_ASSERT(i02 >= 0 && i02 < n_as);
- const int64_t i11 = id % ne11;
- const int64_t i12 = iid1;
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = 0;
- const int64_t i1 = id;
- const int64_t i2 = i12;
+ const int64_t i1 = id;
+ const int64_t i2 = i12;
- src0_1_row.data = src0_1_original + i02*nb02;
- src0_2_row.data = src0_2_original + i02*nb02;
- src1_row.data = src1_original + i11*nb11 + i12*nb12;
- //dst_row.data = dst_original + i1*nb1 + i2*nb2;
+ src0_1_row.data = src0_1_original + i02*nb02;
+ src0_2_row.data = src0_2_original + i02*nb02;
+ src1_row.data = src1_original + i11*nb11 + i12*nb12;
+ //dst_row.data = dst_original + i1*nb1 + i2*nb2;
- dst_row.data = dst_up_contiguous.get();
- ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
- CUDA_CHECK(cudaGetLastError());
+ dst_row.data = dst_up_contiguous.get();
+ ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row);
+ CUDA_CHECK(cudaGetLastError());
- dst_row.data = dst_gate_contiguous.get();
- ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
- CUDA_CHECK(cudaGetLastError());
+ dst_row.data = dst_gate_contiguous.get();
+ ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row);
+ CUDA_CHECK(cudaGetLastError());
- if (fuse_down) {
- ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
- (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
- CUDA_CHECK(cudaGetLastError());
+ if (fuse_down) {
+ ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
+ (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
+ CUDA_CHECK(cudaGetLastError());
- final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2];
- final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2];
- ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
- CUDA_CHECK(cudaGetLastError());
+ final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2];
+ final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2];
+ ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst);
+ CUDA_CHECK(cudaGetLastError());
- } else {
+ } else {
- ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
- (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2));
- CUDA_CHECK(cudaGetLastError());
+ ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0],
+ (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2));
+ CUDA_CHECK(cudaGetLastError());
- }
}
}
} else {
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index 2eafe463..576c387d 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -96,11 +96,13 @@ template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y>
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__global__ void iqk_mul_mat_vec_q(
- const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
- const uint64_t nb02, const uint64_t nb12, const uint64_t nb2) {
+ const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) {
int i2 = blockIdx.y;
- const char * cx = (const char *)vx + i2*nb02;
+ int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
+ if (i02 < 0) return;
+ const char * cx = (const char *)vx + i02*nb02;
const char * cy = (const char *)vy + i2*nb12;
char * cdst = (char *)dst + i2*nb2;
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, ncols_y>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
@@ -108,9 +110,9 @@ __global__ void iqk_mul_mat_vec_q(
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
void iqk_mul_mat_vec_q_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
//GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
@@ -152,28 +154,28 @@ void iqk_mul_mat_vec_q_cuda(
switch (ncols_y) {
case 1:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
break;
case 2:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
break;
case 3:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
break;
case 4:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
break;
case 5:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
break;
case 6:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
break;
case 7:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
break;
case 8:
- iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2);
+ iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
break;
default:
GGML_ASSERT(false);
@@ -742,79 +744,79 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
} // namespace
void mul_mat_vec_iq2_k_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K, VDR_IQ2_K_Q8_1_MMVQ, vec_dot_iq2_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_K, VDR_IQ2_K_Q8_1_MMVQ, vec_dot_iq2_k_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
void mul_mat_vec_iq3_k_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_K, VDR_IQ3_K_Q8_1_MMVQ, vec_dot_iq3_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_K, VDR_IQ3_K_Q8_1_MMVQ, vec_dot_iq3_k_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
void mul_mat_vec_iq4_k_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_iq4_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_iq4_k_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
void mul_mat_vec_iq4_ks_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
void mul_mat_vec_iq4_kss_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KSS, VDR_IQ4_KSS_Q8_1_MMVQ, vec_dot_iq4_kss_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KSS, VDR_IQ4_KSS_Q8_1_MMVQ, vec_dot_iq4_kss_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
void mul_mat_vec_iq2_ks_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KS, VDR_IQ2_KS_Q8_1_MMVQ, vec_dot_iq2_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_KS, VDR_IQ2_KS_Q8_1_MMVQ, vec_dot_iq2_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
void mul_mat_vec_iq5_k_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
void mul_mat_vec_iq6_k_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
void mul_mat_vec_iq1_bn_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN, 1, vec_dot_iq1_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN, 1, vec_dot_iq1_bn_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
void mul_mat_vec_iq2_bn_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
- iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN, 1, vec_dot_iq2_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+ iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN, 1, vec_dot_iq2_bn_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index a128bc53..15df20f5 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -1,51 +1,51 @@
#include "common.cuh"
void mul_mat_vec_iq2_k_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
void mul_mat_vec_iq3_k_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
void mul_mat_vec_iq4_k_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
void mul_mat_vec_iq5_k_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
void mul_mat_vec_iq6_k_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
void mul_mat_vec_iq4_ks_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
void mul_mat_vec_iq4_kss_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
void mul_mat_vec_iq2_ks_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
void mul_mat_vec_iq1_bn_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
void mul_mat_vec_iq2_bn_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream);
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index e17e77a3..b9e9c216 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -134,11 +134,12 @@ template <ggml_type type, int ncols_y>
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q(
- const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst,
- const uint64_t nb02, const uint64_t nb12, const uint64_t nb2) {
+ const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) {
int i2 = blockIdx.y;
- const char * cx = (const char *)vx + i2*nb02;
+ int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
+ const char * cx = (const char *)vx + i02*nb02;
const char * cy = (const char *)vy + i2*nb12;
char * cdst = (char *)dst + i2*nb2;
mul_mat_vec_q<type, ncols_y>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
@@ -146,9 +147,9 @@ static __global__ void mul_mat_vec_q(
template <ggml_type type>
static void mul_mat_vec_q_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) {
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
@@ -188,28 +189,28 @@ static void mul_mat_vec_q_cuda(
switch (ncols_y) {
case 1:
- mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
+ mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 2:
- mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
+ mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 3:
- mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
+ mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 4:
- mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
+ mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 5:
- mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
+ mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 6:
- mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
+ mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 7:
- mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
+ mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
case 8:
- mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2);
+ mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
break;
default:
GGML_ABORT("fatal error");
@@ -218,169 +219,169 @@ static void mul_mat_vec_q_cuda(
}
static void mul_mat_vec_q4_0_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_q4_1_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_q5_0_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_q5_1_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_q6_0_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_Q6_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_Q6_0>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_q8_0_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_q2_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_q3_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_q4_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_q5_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_q6_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_iq2_xxs_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_iq2_xs_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_iq2_s_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_iq3_xxs_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_iq1_s_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_iq1_m_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_iq4_nl_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_iq4_xs_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void mul_mat_vec_iq3_s_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
+ const void * vx, const void * vy, float * dst, const char * ids_data,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
- const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) {
+ const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
- mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
}
static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type,
const int64_t ne00, const int64_t ne0, const int64_t ne2,
- const int64_t nb02, const int64_t nb12, const int64_t nb2,
- const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i,
+ const int64_t nb02, const int64_t nb12, const int64_t nb2, const int64_t ids_nb0,
+ const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i, const char * ids_data,
const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
@@ -391,98 +392,97 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
- const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size);
switch (type) {
case GGML_TYPE_Q4_0:
- mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_Q4_1:
- mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_Q5_0:
- mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_Q5_1:
- mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_Q6_0:
- mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_Q8_0:
- mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_Q2_K:
- mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_Q3_K:
- mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_Q4_K:
- mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_Q5_K:
- mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_Q6_K:
- mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ2_XXS:
- mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ2_XS:
- mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ2_S:
- mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ3_XXS:
- mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ1_S:
- mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ1_M:
- mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ1_BN:
- mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ2_BN:
- mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ4_NL:
- mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ4_XS:
- mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ2_K:
- mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ3_K:
- mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ4_K:
- mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ4_KS:
- mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ4_KSS:
- mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ2_KS:
- mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ5_K:
- mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ6_K:
- mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
case GGML_TYPE_IQ3_S:
- mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream);
+ mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
break;
default:
GGML_ABORT("fatal error");
@@ -505,14 +505,12 @@ void ggml_cuda_op_mul_mat_vec_q_3D(
const int64_t ne0 = dst->ne[0];
- int id = ggml_cuda_get_device();
-
const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size);
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
ne00, ne0, dst->ne[2],
- src0->nb[2], src1_row_size, dst->nb[2],
- src0_dd_i, src1_ddq_i, dst_dd_i,
+ src0->nb[2], src1_row_size, dst->nb[2], 0,
+ src0_dd_i, src1_ddq_i, dst_dd_i, nullptr,
row_low, row_high, src1_ncols,
src1_padded_row_size, stream);
@@ -531,11 +529,35 @@ void ggml_cuda_op_mul_mat_vec_q(
const int64_t ne0 = dst->ne[0];
- int id = ggml_cuda_get_device();
+ ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
+ ne00, ne0, 1, 0, 0, 0, 0,
+ src0_dd_i, src1_ddq_i, dst_dd_i, nullptr,
+ row_low, row_high, src1_ncols,
+ src1_padded_row_size, stream);
+
+ GGML_UNUSED(src1_ddf_i);
+}
+
+void ggml_cuda_op_mul_mat_vec_q_id(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
+ const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne10 = src1->ne[0];
+ GGML_ASSERT(ne10 % QK8_1 == 0);
+ GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1);
+ GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1);
+ GGML_ASSERT(ids->ne[0] == dst->ne[2]);
+
+ const int64_t ne0 = dst->ne[0];
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
- ne00, ne0, 1, 0, 0, 0,
- src0_dd_i, src1_ddq_i, dst_dd_i,
+ ne00, ne0, dst->ne[2],
+ src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0],
+ src0_dd_i, src1_ddq_i, dst_dd_i, (const char *)ids->data,
row_low, row_high, src1_ncols,
src1_padded_row_size, stream);
diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh
index f86aefe2..c0699b04 100644
--- a/ggml/src/ggml-cuda/mmvq.cuh
+++ b/ggml/src/ggml-cuda/mmvq.cuh
@@ -11,3 +11,9 @@ void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
+
+void ggml_cuda_op_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
+ const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream);
diff --git a/src/llama.cpp b/src/llama.cpp
index 7d665072..bad8d33d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13734,6 +13734,9 @@ struct llm_build_context {
}
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
+ if (kv_cache->ne[1] < 256) {
+ ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+ }
cb(kq, "kq", il);
if (!pp_opt) {