diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-03-12 07:21:46 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-03-12 07:21:46 +0200 |
commit | 3f23ed68f17583a8ee63afd0c214f5b39226226c (patch) | |
tree | ad86914fd2925935247d2fba0ebb3b8b5d2c9bfc | |
parent | a48e16324770bb829406d06e11be1df0c8a3b517 (diff) |
MLA-2: Allow usage of q8_0 for KV cache on CUDA (#252)
* FlashMLA(CUDA): WIP to allow q8_0 quantized cache
* WIP
* FlashMLA(CUDA) - allow q8_0 for KV cache
This works, and PP is not bad, but TG is still quite a bit slower.
* FlashMLA(CUDA) - allow q8_0 for KV cache
This is better. ~9% slower than f16 cache for short contexts,
nearly on par at 16k tokens.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-cuda.cu | 11 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/cpy.cu | 137 |
2 files changed, 141 insertions, 7 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index f25dd725..1bb869c3 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2296,9 +2296,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * for (int64_t id = 0; id < n_ids; id++) { const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - if (i02 < 0 || i02 >= n_as) continue; - //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); - if (row_id_i != i02) { continue; } @@ -3458,6 +3455,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { return true; } + if (ggml_is_contiguous(op->src[0]) && ggml_are_same_shape(op->src[0], op->src[1])) { + if (src1_type == GGML_TYPE_F16 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F32) { + return true; + } + } + if (ggml_are_same_shape(op->src[0], op->src[1]) && op->src[0]->type == GGML_TYPE_Q8_0 && op->src[1]->type == GGML_TYPE_Q8_0) { + return true; + } return false; } break; case GGML_OP_DUP: diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index fabe8843..44eba389 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,4 +1,5 @@ #include "cpy.cuh" +#include "convert.cuh" typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -65,6 +66,66 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } +//static __global__ void cpy_q8_0_f32(const char * cx, float * dst, const int ne, +// const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) { +// const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; +// +// if (i >= ne) { +// return; +// } +// +// const int64_t i03 = i/(ne00 * ne01 * ne02); +// const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01); +// const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00; +// const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00; +// +// const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03); +// const int ib = i00/QK8_0; +// const int iq = i00%QK8_0; +// +// dst[i00*ne01 + i01 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq]; +//} + +static __global__ void k_transpose_q8_0(const char * cx, char * cdst, + const int ne10, const int ne11, const int ne12, + const int nb01, const int nb02, const int nb03, + const int nb11, const int nb12, const int nb13) { + const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + + //const int64_t ne00 = ne11; + //const int64_t ne01 = ne10; + //const int64_t ne02 = ne12; + const int64_t i03 = i13; + const int64_t i02 = i12; + const int64_t i01 = i10; //(i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00; + const int64_t i00 = i11; //i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00; + + const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03); + const int ib0 = i00/QK8_0; + const int iq0 = i00%QK8_0; + + float xi = __half2float(q8[ib0].d)*q8[ib0].qs[iq0]; + float amax = fabsf(xi); + amax = warp_reduce_max(amax); + + //printf("%d, %d, %d: i = %ld, i11 = %ld i10 = %ld, xi = %g, amax = %g\n", blockDim.x, blockIdx.x, threadIdx.x, i, i11, i10, xi, amax); + + float d = amax/127; + int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + block_q8_0 * dst = (block_q8_0 *)(cdst + i11*nb11 + i12*nb12 + i13*nb13); + dst[i10 / QK8_0].qs[i10 % QK8_0] = q; + + if (threadIdx.x == 0) { + dst[i10 / QK8_0].d = __float2half(d); + } +} + static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { const float * xi = (const float *) cxi; block_q8_0 * dsti = (block_q8_0 *) cdsti; @@ -464,6 +525,32 @@ static void ggml_cpy_f16_f16_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { + auto stream = ctx.stream(); + auto num_blocks = ggml_nelements(dst)/QK8_0; + k_transpose_q8_0<<<num_blocks, QK8_0, 0, stream>>>( + (const char *)src->data, (char *)dst->data, + dst->ne[0], dst->ne[1], dst->ne[2], src->nb[0], src->nb[2], src->nb[3], + dst->nb[1], dst->nb[2], dst->nb[3]); + + //auto ne = ggml_nelements(dst); + //ggml_cuda_pool_alloc<float> dst_f32(ctx.pool(), ne); + //const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + //auto aux_src = *dst; + //aux_src.nb[0] = sizeof(float); + //aux_src.nb[1] = aux_src.nb[0]*aux_src.ne[0]; + //aux_src.nb[2] = aux_src.nb[1]*aux_src.ne[1]; + //aux_src.nb[3] = aux_src.nb[2]*aux_src.ne[2]; + //cpy_q8_0_f32<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> + // ((const char *)src->data, dst_f32.get(), ne, + // src->ne[1], src->ne[0], src->ne[2], src->nb[0], src->nb[2], src->nb[3]); + //CUDA_CHECK(cudaGetLastError()); + //aux_src.type = GGML_TYPE_F32; + //ggml_cpy_f32_q8_0_cuda((const char *)dst_f32.get(), (char *)dst->data, ne, dst->ne[0], dst->ne[1], dst->ne[2], + // aux_src.nb[0], aux_src.nb[1], aux_src.nb[2], aux_src.nb[3], + // dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], stream); +} + void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -522,9 +609,33 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) { + if (src1->type == GGML_TYPE_F16) { + auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); + if (to_fp16) { + to_fp16(src0->data, (half *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); + } + } + else if (src1->type == GGML_TYPE_F32) { + auto to_fp32 = ggml_get_to_fp32_cuda(src0->type); + if (to_fp32) { + to_fp32(src0->data, (float *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); + } + } + else if (src1->type == GGML_TYPE_BF16) { + auto to_bf16 = ggml_get_to_bf16_cuda(src0->type); + if (to_bf16) { + to_bf16(src0->data, (nv_bfloat16 *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); + } + } + } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) { + transpose_q8_0(ctx, src0, src1); } else { fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); + fprintf(stderr, "%s: %ld x %ld x %ld; %zu x %zu %zu -> %ld x %ld x %ld; %zu x %zu x %zu\n", __func__, + src0->ne[0], src0->ne[1], src0->ne[2], src0->nb[1], src0->nb[2], src0->nb[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[1], src1->nb[2], src1->nb[3]); GGML_ABORT("fatal error"); } } @@ -559,9 +670,27 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_f32_f16<cpy_1_f16_f16>; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { return (void*) cpy_f32_f16<cpy_1_f16_f32>; - } else { - fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, - ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ABORT("fatal error"); + } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) { + if (src1->type == GGML_TYPE_F16) { + auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); + if (to_fp16) return (void*)to_fp16; + } + else if (src1->type == GGML_TYPE_F32) { + auto to_fp32 = ggml_get_to_fp32_cuda(src0->type); + if (to_fp32) return (void*)to_fp32; + } + else if (src1->type == GGML_TYPE_BF16) { + auto to_bf16 = ggml_get_to_bf16_cuda(src0->type); + if (to_bf16) return (void*)to_bf16; + } + } + else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) { + return (void *)transpose_q8_0; } + fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); + fprintf(stderr, "%s: %ld x %ld x %ld; %zu x %zu %zu -> %ld x %ld x %ld; %zu x %zu x %zu\n", __func__, + src0->ne[0], src0->ne[1], src0->ne[2], src0->nb[1], src0->nb[2], src0->nb[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[1], src1->nb[2], src1->nb[3]); + GGML_ABORT("fatal error"); } |