diff options
-rw-r--r-- | common/common.cpp | 7 | ||||
-rw-r--r-- | common/common.h | 2 | ||||
-rw-r--r-- | examples/llama-bench/llama-bench.cpp | 16 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu | 26 | ||||
-rw-r--r-- | include/llama.h | 2 | ||||
-rw-r--r-- | src/llama.cpp | 102 |
6 files changed, 64 insertions, 91 deletions
diff --git a/common/common.cpp b/common/common.cpp index 464b4710..6359426f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -851,7 +851,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-mla" || arg == "--mla-use") { - params.mla_attn = true; + CHECK_ARG + params.mla_attn = std::stoi(argv[i]); return true; } if (arg == "-fmoe" || arg == "--fused-moe") { @@ -1514,7 +1515,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); - options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn }); options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" @@ -3357,7 +3358,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); - fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false"); + fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn); fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index 152fd1cf..ef5175f3 100644 --- a/common/common.h +++ b/common/common.h @@ -175,7 +175,7 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention - bool mla_attn = false; // MLA + int mla_attn = false; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 438d2a7c..5756843a 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -232,7 +232,7 @@ struct cmd_params { std::vector<int> main_gpu; std::vector<bool> no_kv_offload; std::vector<bool> flash_attn; - std::vector<bool> mla_attn; + std::vector<int> mla_attn; std::vector<std::vector<float>> tensor_split; std::vector<bool> use_mmap; std::vector<bool> embeddings; @@ -264,7 +264,7 @@ static const cmd_params cmd_params_defaults = { /* main_gpu */ {0}, /* no_kv_offload */ {false}, /* flash_attn */ {false}, - /* mla_attn */ {false}, + /* mla_attn */ {0}, /* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -300,7 +300,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); - printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); + printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa <distribute|isolate|numactl> (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -576,7 +576,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { invalid_param = true; break; } - auto p = string_split<bool>(argv[i], split_delim); + auto p = string_split<int>(argv[i], split_delim); params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { @@ -726,7 +726,7 @@ struct cmd_params_instance { int main_gpu; bool no_kv_offload; bool flash_attn; - bool mla_attn; + int mla_attn; std::vector<float> tensor_split; bool use_mmap; bool embeddings; @@ -955,7 +955,7 @@ struct test { int main_gpu; bool no_kv_offload; bool flash_attn; - bool mla_attn; + int mla_attn; std::vector<float> tensor_split; bool use_mmap; bool embeddings; @@ -1097,13 +1097,13 @@ struct test { field == "n_threads" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" || field == "main_gpu" || - field == "n_prompt" || field == "n_gen" || + field == "n_prompt" || field == "n_gen" || field == "mla_attn" || field == "avg_ns" || field == "stddev_ns") { return INT; } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || + field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "fused_moe") { return BOOL; } diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 2a520d68..e17e77a3 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -158,11 +158,6 @@ static void mul_mat_vec_q_cuda( int64_t nwarps = 1; int64_t rows_per_cuda_block = 1; - //if (ne2 > 1) { - // printf("%s: ncols_x = %d, nrows_x = %d, nrows_y = %d, ncols_y = %d nrows_dst = %d, ne2 = %d nb02 = %zu, nb12 = %zu, nb2 = %zu\n", __func__, - // ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2); - //} - if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 switch(ncols_y) { case 1: @@ -382,9 +377,8 @@ static void mul_mat_vec_iq3_s_q8_1_cuda( mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } -namespace { -void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type, - const int64_t ne00, const int64_t ne10, const int64_t ne0, const int64_t ne2, +static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type, + const int64_t ne00, const int64_t ne0, const int64_t ne2, const int64_t nb02, const int64_t nb12, const int64_t nb2, const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, @@ -496,7 +490,6 @@ void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type } } -} void ggml_cuda_op_mul_mat_vec_q_3D( ggml_backend_cuda_context & ctx, @@ -505,8 +498,6 @@ void ggml_cuda_op_mul_mat_vec_q_3D( const int64_t src1_padded_row_size, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1); @@ -516,13 +507,10 @@ void ggml_cuda_op_mul_mat_vec_q_3D( int id = ggml_cuda_get_device(); - // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the kernel writes into - const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size); ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, - ne00, ne10, ne0, dst->ne[2], + ne00, ne0, dst->ne[2], src0->nb[2], src1_row_size, dst->nb[2], src0_dd_i, src1_ddq_i, dst_dd_i, row_low, row_high, src1_ncols, @@ -538,8 +526,6 @@ void ggml_cuda_op_mul_mat_vec_q( const int64_t src1_padded_row_size, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); @@ -547,12 +533,8 @@ void ggml_cuda_op_mul_mat_vec_q( int id = ggml_cuda_get_device(); - // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the kernel writes into - const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, - ne00, ne10, ne0, 1, 0, 0, 0, + ne00, ne0, 1, 0, 0, 0, src0_dd_i, src1_ddq_i, dst_dd_i, row_low, row_high, src1_ncols, src1_padded_row_size, stream); diff --git a/include/llama.h b/include/llama.h index beb6ecba..2b33701c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -383,7 +383,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] - bool mla_attn; // whether to use MLA attention [EXPERIMENTAL] + int mla_attn; // whether to use MLA attention [EXPERIMENTAL] bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL] // Abort callback diff --git a/src/llama.cpp b/src/llama.cpp index ebc7a772..f2c5f9d4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -112,14 +112,6 @@ #define LLAMA_MAX_EXPERTS 256 // DeepSeekV2 // -// === MLA cache -// If tou are desperate to reduce KV cache size, set MLA_USE_TRANSPOSED_CACHE to 0. -// TG perfornce will be slower (similar to no-MLA), but KV cache size will be cut to ~half. -// PP performance will be about the same as with MLA_USE_TRANSPOSED_CACHE = 1. -// -#define MLA_USE_TRANSPOSED_CACHE 1 - -// // helpers // @@ -2518,7 +2510,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; - bool mla_attn; + int mla_attn; bool fused_moe_up_gate; enum llama_pooling_type pooling_type; @@ -2695,9 +2687,7 @@ struct llama_kv_cache { // DeepSeek MLA std::vector<struct ggml_tensor *> kv_l; -#if MLA_USE_TRANSPOSED_CACHE std::vector<struct ggml_tensor *> kvt_l; -#endif std::vector<struct ggml_context *> ctxs; std::vector<ggml_backend_buffer_t> bufs; @@ -3175,9 +3165,9 @@ static bool llama_kv_cache_init( // DeepSeek MLA cache.kv_l.reserve(n_layer); -#if MLA_USE_TRANSPOSED_CACHE - cache.kvt_l.reserve(n_layer); -#endif + if (cparams.mla_attn == 1) { + cache.kvt_l.reserve(n_layer); + } bool warn = true; int n_mla = 0; @@ -3208,25 +3198,18 @@ static bool llama_kv_cache_init( const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); -#if MLA_USE_TRANSPOSED_CACHE - ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); - //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); -#else - ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_v, kv_lora_rank + n_embd_head_qk_rope, kv_size); -#endif + auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; + ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); ggml_format_name(kv, "cache_kv_l%d", i); cache.kv_l.push_back(kv); -#if MLA_USE_TRANSPOSED_CACHE - ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); - ggml_format_name(kvt, "cache_kvt_l%d", i); - cache.kvt_l.push_back(kvt); -#endif + if (cparams.mla_attn == 1) { + ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); + ggml_format_name(kvt, "cache_kvt_l%d", i); + cache.kvt_l.push_back(kvt); + } n_mla++; } else { - //printf("Creating cache tensors:\n"); - //printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k); - //k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size); v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); ggml_format_name(k, "cache_k_l%d", i); @@ -8940,7 +8923,7 @@ struct llm_build_context { const int32_t n_ctx_orig; const bool flash_attn; - const bool mla_attn; + const int mla_attn; const bool fused_moe_up_gate; const enum llama_pooling_type pooling_type; @@ -13546,20 +13529,22 @@ struct llm_build_context { if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) { -#if MLA_USE_TRANSPOSED_CACHE - ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, - ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head)); - cb(kv_cache_trans_view, "kv_cache_trans_view", il); + ggml_tensor * kv_cache_trans; - // note: storing transposed c^KV in the transposed KV cache - ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view)); + if (lctx.cparams.mla_attn == 1) { + ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, + ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head)); + cb(kv_cache_trans_view, "kv_cache_trans_view", il); - ggml_tensor * kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il], - n_kv, kv_lora_rank, - ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), - 0); - cb(kv_cache_trans, "kv_cache_trans", il); -#endif + // note: storing transposed c^KV in the transposed KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view)); + + kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il], + n_kv, kv_lora_rank, + ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), + 0); + cb(kv_cache_trans, "kv_cache_trans", il); + } ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0); cb(kvr, "kvr", il); @@ -13607,15 +13592,15 @@ struct llm_build_context { cb(kq, "kq_soft_max_ext_perm", il); } -#if !MLA_USE_TRANSPOSED_CACHE - ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], - kv_lora_rank, n_kv, - ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); - cb(kv_cache, "kv_cache_lora", il); + if (lctx.cparams.mla_attn > 1) { + ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cache, "kv_cache_lora", il); - ggml_tensor * kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); - cb(kv_cache_trans, "kv_cache_trans", il); -#endif + kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); + cb(kv_cache_trans, "kv_cache_trans", il); + } struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); cb(kqv_compressed, "kqv_compressed", il); @@ -17658,7 +17643,7 @@ struct llama_context_params llama_context_default_params() { /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, - /*.mla_attn =*/ false, + /*.mla_attn =*/ 0, /*.fused_moe_up_gate =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, @@ -18140,18 +18125,23 @@ struct llama_context * llama_new_context_with_model( kv_type = kv->type; } -#if MLA_USE_TRANSPOSED_CACHE for (auto & kvt : ctx->kv_self.kvt_l) { memory_size_kvt += ggml_nbytes(kvt); kvt_type = kvt->type; } -#endif if (memory_size_kv + memory_size_kvt > 0) { - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__, - (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), - ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f), - ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f)); + if (cparams.mla_attn == 1) { + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__, + (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), + ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f), + ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f)); + } else { + GGML_ASSERT(memory_size_kvt == 0); + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T: not used\n", __func__, + (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), + ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f)); + } } } |