diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-03-01 08:25:27 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-03-01 08:25:27 +0200 |
commit | a79ab8f34222e1e0142a30eaa97e78ad077abca9 (patch) | |
tree | 24f89079780736d697347e1ebbe6544750534e22 | |
parent | b762db7c9264199c2d0f66e7d63e3b4884f3fc0c (diff) |
Reduce size of compute buffers (#237)
* This reduces compute buffer size for MLA
* This should accomplish it for standard attention
* Much better
* Better concat for contiguous tensors
If all the op does is to concatenate the second tensor
to the first, why would we want to have a loop?
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | common/common.cpp | 8 | ||||
-rw-r--r-- | common/common.h | 3 | ||||
-rw-r--r-- | examples/llama-bench/llama-bench.cpp | 35 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/concat.cu | 30 | ||||
-rw-r--r-- | ggml/src/ggml.c | 20 | ||||
-rw-r--r-- | include/llama.h | 1 | ||||
-rw-r--r-- | src/llama.cpp | 216 |
7 files changed, 235 insertions, 78 deletions
diff --git a/common/common.cpp b/common/common.cpp index 6359426f..5c9070da 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -855,6 +855,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.mla_attn = std::stoi(argv[i]); return true; } + if (arg == "-amb" || arg == "--attention-max-batch") { + CHECK_ARG + params.attn_max_batch = std::stoi(argv[i]); + return true; + } if (arg == "-fmoe" || arg == "--fused-moe") { params.fused_moe_up_gate = true; return true; @@ -1516,6 +1521,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn }); + options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch}); options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" @@ -2360,6 +2366,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate = params.fused_moe_up_gate; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); @@ -3359,6 +3366,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn); + fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch); fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index ef5175f3..f35f3558 100644 --- a/common/common.h +++ b/common/common.h @@ -175,7 +175,8 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention - int mla_attn = false; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache + int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache + int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false) bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 5756843a..a08cb762 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -233,6 +233,7 @@ struct cmd_params { std::vector<bool> no_kv_offload; std::vector<bool> flash_attn; std::vector<int> mla_attn; + std::vector<int> attn_max_batch; std::vector<std::vector<float>> tensor_split; std::vector<bool> use_mmap; std::vector<bool> embeddings; @@ -265,6 +266,7 @@ static const cmd_params cmd_params_defaults = { /* no_kv_offload */ {false}, /* flash_attn */ {false}, /* mla_attn */ {0}, + /* attn_max_batch */ {0}, /* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -301,6 +303,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); + printf(" -amb, --attn-max-batch <i> (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa <distribute|isolate|numactl> (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -578,6 +581,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split<int>(argv[i], split_delim); params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end()); + } else if (arg == "-amb" || arg == "--attn-max-batch") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split<int>(argv[i], split_delim); + params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -690,6 +700,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; } + if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -727,6 +738,7 @@ struct cmd_params_instance { bool no_kv_offload; bool flash_attn; int mla_attn; + int attn_max_batch; std::vector<float> tensor_split; bool use_mmap; bool embeddings; @@ -773,6 +785,7 @@ struct cmd_params_instance { cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; cparams.mla_attn = mla_attn; + cparams.attn_max_batch = attn_max_batch; cparams.fused_moe_up_gate = fmoe; cparams.embeddings = embeddings; @@ -799,6 +812,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param for (const auto & nkvo : params.no_kv_offload) for (const auto & fa : params.flash_attn) for (const auto & mla : params.mla_attn) + for (const auto & amb : params.attn_max_batch) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -821,6 +835,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -852,6 +867,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -883,6 +899,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -914,6 +931,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -956,6 +974,7 @@ struct test { bool no_kv_offload; bool flash_attn; int mla_attn; + int attn_max_batch; std::vector<float> tensor_split; bool use_mmap; bool embeddings; @@ -987,6 +1006,7 @@ struct test { no_kv_offload = inst.no_kv_offload; flash_attn = inst.flash_attn; mla_attn = inst.mla_attn; + attn_max_batch = inst.attn_max_batch; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -1081,7 +1101,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", + "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -1097,7 +1117,7 @@ struct test { field == "n_threads" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" || field == "main_gpu" || - field == "n_prompt" || field == "n_gen" || field == "mla_attn" || + field == "n_prompt" || field == "n_gen" || field == "mla_attn" || field == "attn_max_batch" || field == "avg_ns" || field == "stddev_ns") { return INT; } @@ -1138,7 +1158,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -1305,6 +1325,9 @@ struct markdown_printer : public printer { if (field == "mla_attn") { return 3; } + if (field == "attn_max_batch") { + return 5; + } if (field == "use_mmap") { return 4; } @@ -1345,6 +1368,9 @@ struct markdown_printer : public printer { if (field == "mla_attn") { return "mla"; } + if (field == "attn_max_batch") { + return "amb"; + } if (field == "use_mmap") { return "mmap"; } @@ -1403,6 +1429,9 @@ struct markdown_printer : public printer { if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) { fields.emplace_back("mla_attn"); } + if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) { + fields.emplace_back("attn_max_batch"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index dac10ec3..4bde6d69 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -164,7 +164,12 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; - if (dim != 3) { + if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + } else { for (int i3 = 0; i3 < dst->ne[3]; i3++) { concat_f32_cuda( src0_d + i3 * (src0->nb[3] / 4), @@ -173,13 +178,24 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); } - } else { - const size_t size0 = ggml_nbytes(src0); - const size_t size1 = ggml_nbytes(src1); - - CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); } + + //if (dim != 3) { + // for (int i3 = 0; i3 < dst->ne[3]; i3++) { + // concat_f32_cuda( + // src0_d + i3 * (src0->nb[3] / 4), + // src1_d + i3 * (src1->nb[3] / 4), + // dst_d + i3 * ( dst->nb[3] / 4), + // src0->ne[0], src0->ne[1], src0->ne[2], + // dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); + // } + //} else { + // const size_t size0 = ggml_nbytes(src0); + // const size_t size1 = ggml_nbytes(src1); + + // CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); + // CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + //} } else { dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>( diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 80dd25ff..91c0c5db 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -12627,6 +12627,26 @@ static void ggml_compute_forward_concat_f32( GGML_ASSERT(dim >= 0 && dim < 4); + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst) && + (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) { + // simply copy the data + const int64_t size_src_0 = ggml_nbytes(src0); + const int64_t size_src_1 = ggml_nbytes(src1); + const int64_t block_size = 4096; + const int64_t num_blocks = (size_src_0 + size_src_1 + block_size - 1)/block_size; + for (int64_t i_block = ith; i_block < num_blocks; i_block += nth) { + const int64_t start = i_block*block_size; + if (start < size_src_0) { + int64_t copy_size = MIN(block_size, size_src_0 - start); + memcpy((char *)dst->data + start, (char *)src0->data + start, copy_size); + } else { + int64_t copy_size = MIN(block_size, size_src_0 + size_src_1 - start); + memcpy((char *)dst->data + start, (char *)src1->data + start - size_src_0, copy_size); + } + } + return; + } + int64_t o[4] = {0, 0, 0, 0}; o[dim] = src0->ne[dim]; diff --git a/include/llama.h b/include/llama.h index 2b33701c..bb43aebc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -384,6 +384,7 @@ extern "C" { bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] int mla_attn; // whether to use MLA attention [EXPERIMENTAL] + int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL] bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL] // Abort callback diff --git a/src/llama.cpp b/src/llama.cpp index f2c5f9d4..0dcc78dc 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2511,6 +2511,7 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; int mla_attn; + int attn_max_batch; bool fused_moe_up_gate; enum llama_pooling_type pooling_type; @@ -8774,61 +8775,108 @@ static struct ggml_tensor * llm_build_kqv( cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens); } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); - if (model.arch == LLM_ARCH_GROK) { - // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 - // and then : - // kq = 30 * tanh(kq / 30) - // before the softmax below + auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); + if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]) { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); - //try from phi2 //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - //kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); - //kq = ggml_scale(ctx, kq, 30); + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } - kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f); - } + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below - if (hparams.attn_soft_cap) { - //kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); - kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias, - 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); - } - cb(kq, "kq_soft_max_ext", il); + //try from phi2 + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - GGML_ASSERT(kv.size == n_ctx); + //kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + //kq = ggml_scale(ctx, kq, 30); - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, - 0); - cb(v, "v", il); + kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f); + } + + if (hparams.attn_soft_cap) { + //kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); + kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias, + 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + } + cb(kq, "kq_soft_max_ext", il); + + GGML_ASSERT(kv.size == n_ctx); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); - cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + } + else { + // For now we will not support this option if k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]; + GGML_ASSERT(k->ne[2] == v->ne[2] && k->ne[2] == q->ne[2]); + int n_step = (kq_size + cparams.attn_max_batch - 1)/cparams.attn_max_batch; + n_step = std::min(n_step, int(k->ne[2])); + int n_per_step = (q->ne[2] + n_step - 1)/n_step; + auto r2k = q->ne[2] / k->ne[2]; + auto r2v = q->ne[2] / v->ne[2]; + n_step = q->ne[2]; + n_per_step = 1; + ggml_tensor * kqv; + for (int i12 = 0; i12 < q->ne[2]; i12 += n_per_step) { + int this_ne12 = i12 + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i12; + int i02 = i12/r2k; + auto k_i = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], this_ne12, k->nb[1], k->nb[2], k->nb[2]*i02); + auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12); + auto kq_i = ggml_mul_mat(ctx, k_i, q_i); + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { + ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32); + } + if (model.arch == LLM_ARCH_GROK) { + kq_i = ggml_softcap(ctx, kq_i, 0.08838834764831845f/30.0f, 30.f); + } + if (hparams.attn_soft_cap) { + kq_i = ggml_softcap_max(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias, + 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); + } else { + kq_i = ggml_soft_max_ext(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias); + } + i02 = i12 / r2v; + auto v_i = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], this_ne12, v->nb[1], v->nb[2], v->nb[2]*i02); + auto kqv_i = ggml_mul_mat(ctx, v_i, kq_i); + if (i12 == 0) { + kqv = kqv_i; + } else { + kqv = ggml_concat(ctx, kqv, kqv_i, 2); + } + } + ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + } } ggml_build_forward_expand(graph, cur); @@ -8924,6 +8972,7 @@ struct llm_build_context { const bool flash_attn; const int mla_attn; + const int attn_max_batch; const bool fused_moe_up_gate; const enum llama_pooling_type pooling_type; @@ -8976,6 +9025,7 @@ struct llm_build_context { n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), mla_attn (cparams.mla_attn), + attn_max_batch (cparams.attn_max_batch), fused_moe_up_gate(cparams.fused_moe_up_gate), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), @@ -13572,25 +13622,6 @@ struct llm_build_context { ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); cb(q, "q", il); - if (!pp_opt) { - q = ggml_permute(ctx0, q, 0, 2, 1, 3); - cb(q, "q_perm", il); - } - ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); - cb(kq, "kq", il); - - if (!pp_opt) { - kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); - cb(kq, "kq_perm", il); - } - - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - - if (!pp_opt) { - kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); - cb(kq, "kq_soft_max_ext_perm", il); - } if (lctx.cparams.mla_attn > 1) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], @@ -13602,12 +13633,60 @@ struct llm_build_context { cb(kv_cache_trans, "kv_cache_trans", il); } - struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); - cb(kqv_compressed, "kqv_compressed", il); + ggml_tensor * kqv_compressed; + auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB + if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) { + if (!pp_opt) { + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + cb(q, "q_perm", il); + } + + ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); + cb(kq, "kq", il); + + if (!pp_opt) { + kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); + cb(kq, "kq_perm", il); + } + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + if (!pp_opt) { + kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); + cb(kq, "kq_soft_max_ext_perm", il); + } + + kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); + cb(kqv_compressed, "kqv_compressed", il); + + if (!pp_opt) { + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } + + } else { + + int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch; + n_step = std::min(n_step, int(q->ne[2])); + int n_per_step = (q->ne[2] + n_step - 1)/n_step; - if (!pp_opt) { - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); - cb(kqv_compressed, "kqv_compressed_perm", il); + //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step); + + for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) { + int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head; + ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head); + ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i); + kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i); + if (i_head == 0) { + kqv_compressed = kqv_i; + } else { + kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2); + } + ggml_build_forward_expand(gf, kqv_compressed); + } + cb(kqv_compressed, "kqv_compressed", il); } struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, @@ -17644,6 +17723,7 @@ struct llama_context_params llama_context_default_params() { /*.offload_kqv =*/ true, /*.flash_attn =*/ false, /*.mla_attn =*/ 0, + /*.attn_max_batch =*/ 0, /*.fused_moe_up_gate =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, @@ -17844,6 +17924,7 @@ struct llama_context * llama_new_context_with_model( cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate= params.fused_moe_up_gate; cparams.pooling_type = params.pooling_type; @@ -17912,6 +17993,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); + LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch); LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); |