author | Kawrakow <iwankawrakow@gmail.com> | 2025-05-09 10:22:48 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-09 10:22:48 +0300 |
commit | 8777fc4855dd1551c20a84cb266f75fa49e9b0e8 (patch) | |
tree | d67ab3b4c004b5043452928147cb3392daa4a828 | |
parent | 496451a1d4c41300ebdb102f12401b8ffa5b1d4b (diff) | |
Fix CUDA FlashMLA-3 with quantized KV cache (#400)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-cuda/fattn-new-mma.cu | 54 |
1 file changed, 37 insertions, 17 deletions
diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
index 796d9c7b..d1484451 100644
--- a/ggml/src/ggml-cuda/fattn-new-mma.cu
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -1362,26 +1362,46 @@ void launch_fattn_new_mma(
         to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream);
         K_data = (char *) K_f16.ptr;
 
-        const size_t bs = ggml_blck_size(K->type);
-        const size_t ts = ggml_type_size(K->type);
-
-        nb11 = nb11*bs*sizeof(half)/ts;
-        nb12 = nb12*bs*sizeof(half)/ts;
-        nb13 = nb13*bs*sizeof(half)/ts;
+        nb11 = K->ne[0]*sizeof(half);
+        nb12 = nb11*K->ne[1];
+        nb13 = nb12*K->ne[2];
+
+        // Original PR in llama.cpp. I don't think that can work when K is not contiguous (e.g., nb11 > nb12), there are
+        // gaps between the rows, etc., as ggml_get_to_fp16_cuda stores into contiguous memory.
+        //const size_t bs = ggml_blck_size(K->type);
+        //const size_t ts = ggml_type_size(K->type);
+
+        //nb11 = nb11*bs*sizeof(half)/ts;
+        //nb12 = nb12*bs*sizeof(half)/ts;
+        //nb13 = nb13*bs*sizeof(half)/ts;
     }
 
     if (need_f16_V && V->type != GGML_TYPE_F16) {
-        V_f16.alloc(ggml_nelements(V));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-        to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
-        V_data = (char *) V_f16.ptr;
-
-        const size_t bs = ggml_blck_size(V->type);
-        const size_t ts = ggml_type_size(V->type);
-
-        nb21 = nb21*bs*sizeof(half)/ts;
-        nb22 = nb22*bs*sizeof(half)/ts;
-        nb23 = nb23*bs*sizeof(half)/ts;
+        if constexpr (DV == 512) {
+            // DeepSeek. In this case the V cache is the same as the K cache, except that
+            // it has 512 elements per row instead of 576.
+            nb21 = nb11;
+            nb22 = nb12;
+            nb23 = nb13;
+            V_data = K_data;
+        } else {
+            V_f16.alloc(ggml_nelements(V));
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+            to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
+            V_data = (char *) V_f16.ptr;
+
+            nb21 = K->ne[0]*sizeof(half);
+            nb22 = nb21*V->ne[1];
+            nb23 = nb22*V->ne[2];
+
+            // Original PR in llama.cpp. Same comment as above for the K cache.
+            //const size_t bs = ggml_blck_size(V->type);
+            //const size_t ts = ggml_type_size(V->type);
+
+            //nb21 = nb21*bs*sizeof(half)/ts;
+            //nb22 = nb22*bs*sizeof(half)/ts;
+            //nb23 = nb23*bs*sizeof(half)/ts;
+        }
     }
 
     int parallel_blocks = 1;
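
A minimal host-side sketch of the stride fix in the K branch above (this is not the ik_llama.cpp code; the helper name `contiguous_f16_strides`, the `uint16_t` stand-in for `half`, and the example dimensions are assumptions for illustration). The point it demonstrates: once a quantized K cache has been dequantized into a dense fp16 buffer, the byte strides must be rebuilt from the logical dimensions `ne[0..2]`, because scaling the original quantized strides by `blck_size*sizeof(half)/type_size` keeps any row padding of a non-contiguous source view even though the dequantized copy has none.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

struct Strides { size_t nb1, nb2, nb3; };

// Byte strides of a dense fp16 tensor with row length ne0, ne1 rows per head
// and ne2 heads: this mirrors what the patch now computes for K_f16.
static Strides contiguous_f16_strides(int64_t ne0, int64_t ne1, int64_t ne2) {
    Strides s;
    s.nb1 = (size_t) ne0 * sizeof(uint16_t);  // uint16_t stands in for half on the host
    s.nb2 = s.nb1 * (size_t) ne1;
    s.nb3 = s.nb2 * (size_t) ne2;
    return s;
}

int main() {
    // Hypothetical MLA K cache: 576 elements per row, 4096 cached tokens, 1 head.
    const Strides s = contiguous_f16_strides(576, 4096, 1);
    // The commented-out llama.cpp approach instead rescaled the *original*
    // quantized strides; for a non-contiguous quantized view (gaps between rows)
    // that carries the gaps into the strides of a buffer that has none, so the
    // kernel would read the wrong rows.
    std::printf("nb1=%zu nb2=%zu nb3=%zu\n", s.nb1, s.nb2, s.nb3);
    return 0;
}
```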
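A second sketch for the `DV == 512` (DeepSeek FlashMLA) branch: the V cache is the same buffer as the K cache, with only the first 512 of the 576 elements per row used as V, so once K has been dequantized the patch lets V alias the fp16 K data and reuse K's strides instead of converting a second time. The buffer sizes, the token count, and the host-side `uint16_t` stand-in for `half` below are assumptions for illustration, not the actual kernel code.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    constexpr int DK       = 576;  // K row length (512 latent + 64 RoPE dims)
    constexpr int DV       = 512;  // V row length: leading part of the same row
    constexpr int n_tokens = 4;

    // Dense fp16 K cache as produced by the dequantization step.
    std::vector<uint16_t> K_f16(DK * n_tokens);
    for (size_t i = 0; i < K_f16.size(); ++i) K_f16[i] = (uint16_t) i;

    // V aliases K: same base pointer, same row stride (DK*sizeof(half)),
    // but only the first DV elements of each row are read.
    const uint16_t * V_data = K_f16.data();
    const size_t     nb21   = DK * sizeof(uint16_t);  // == nb11

    for (int t = 0; t < n_tokens; ++t) {
        const uint16_t * v_row = (const uint16_t *)((const char *) V_data + t*nb21);
        std::printf("token %d: V[0]=%u ... V[%d]=%u\n",
                    t, (unsigned) v_row[0], DV - 1, (unsigned) v_row[DV - 1]);
    }
    return 0;
}
```

This is why the quantized-cache path previously broke for FlashMLA-3: V was converted as if it were an independent tensor, while the kernel expects it to share the K layout.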