summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--common/common.cpp8
-rw-r--r--common/common.h3
-rw-r--r--examples/llama-bench/llama-bench.cpp35
-rw-r--r--ggml/src/ggml-cuda/concat.cu30
-rw-r--r--ggml/src/ggml.c20
-rw-r--r--include/llama.h1
-rw-r--r--src/llama.cpp216
7 files changed, 235 insertions, 78 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 6359426f..5c9070da 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -855,6 +855,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.mla_attn = std::stoi(argv[i]);
return true;
}
+ if (arg == "-amb" || arg == "--attention-max-batch") {
+ CHECK_ARG
+ params.attn_max_batch = std::stoi(argv[i]);
+ return true;
+ }
if (arg == "-fmoe" || arg == "--fused-moe") {
params.fused_moe_up_gate = true;
return true;
@@ -1516,6 +1521,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
+ options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
@@ -2360,6 +2366,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
+ cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate = params.fused_moe_up_gate;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
@@ -3359,6 +3366,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
+ fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
diff --git a/common/common.h b/common/common.h
index ef5175f3..f35f3558 100644
--- a/common/common.h
+++ b/common/common.h
@@ -175,7 +175,8 @@ struct gpt_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
- int mla_attn = false; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
+ int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
+ int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 5756843a..a08cb762 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -233,6 +233,7 @@ struct cmd_params {
std::vector<bool> no_kv_offload;
std::vector<bool> flash_attn;
std::vector<int> mla_attn;
+ std::vector<int> attn_max_batch;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -265,6 +266,7 @@ static const cmd_params cmd_params_defaults = {
/* no_kv_offload */ {false},
/* flash_attn */ {false},
/* mla_attn */ {0},
+ /* attn_max_batch */ {0},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -301,6 +303,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
+ printf(" -amb, --attn-max-batch <i> (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -578,6 +581,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = string_split<int>(argv[i], split_delim);
params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
+ } else if (arg == "-amb" || arg == "--attn-max-batch") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ auto p = string_split<int>(argv[i], split_delim);
+ params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
invalid_param = true;
@@ -690,6 +700,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
+ if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -727,6 +738,7 @@ struct cmd_params_instance {
bool no_kv_offload;
bool flash_attn;
int mla_attn;
+ int attn_max_batch;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -773,6 +785,7 @@ struct cmd_params_instance {
cparams.offload_kqv = !no_kv_offload;
cparams.flash_attn = flash_attn;
cparams.mla_attn = mla_attn;
+ cparams.attn_max_batch = attn_max_batch;
cparams.fused_moe_up_gate = fmoe;
cparams.embeddings = embeddings;
@@ -799,6 +812,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & nkvo : params.no_kv_offload)
for (const auto & fa : params.flash_attn)
for (const auto & mla : params.mla_attn)
+ for (const auto & amb : params.attn_max_batch)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
@@ -821,6 +835,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
+ /* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -852,6 +867,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
+ /* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -883,6 +899,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
+ /* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -914,6 +931,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
+ /* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -956,6 +974,7 @@ struct test {
bool no_kv_offload;
bool flash_attn;
int mla_attn;
+ int attn_max_batch;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -987,6 +1006,7 @@ struct test {
no_kv_offload = inst.no_kv_offload;
flash_attn = inst.flash_attn;
mla_attn = inst.mla_attn;
+ attn_max_batch = inst.attn_max_batch;
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
@@ -1081,7 +1101,7 @@ struct test {
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
- "main_gpu", "no_kv_offload", "flash_attn", "mla_attn",
+ "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch",
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@@ -1097,7 +1117,7 @@ struct test {
field == "n_threads" ||
field == "model_size" || field == "model_n_params" ||
field == "n_gpu_layers" || field == "main_gpu" ||
- field == "n_prompt" || field == "n_gen" || field == "mla_attn" ||
+ field == "n_prompt" || field == "n_gen" || field == "mla_attn" || field == "attn_max_batch" ||
field == "avg_ns" || field == "stddev_ns") {
return INT;
}
@@ -1138,7 +1158,7 @@ struct test {
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
- std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn),
+ std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1305,6 +1325,9 @@ struct markdown_printer : public printer {
if (field == "mla_attn") {
return 3;
}
+ if (field == "attn_max_batch") {
+ return 5;
+ }
if (field == "use_mmap") {
return 4;
}
@@ -1345,6 +1368,9 @@ struct markdown_printer : public printer {
if (field == "mla_attn") {
return "mla";
}
+ if (field == "attn_max_batch") {
+ return "amb";
+ }
if (field == "use_mmap") {
return "mmap";
}
@@ -1403,6 +1429,9 @@ struct markdown_printer : public printer {
if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) {
fields.emplace_back("mla_attn");
}
+ if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) {
+ fields.emplace_back("attn_max_batch");
+ }
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.emplace_back("tensor_split");
}
diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu
index dac10ec3..4bde6d69 100644
--- a/ggml/src/ggml-cuda/concat.cu
+++ b/ggml/src/ggml-cuda/concat.cu
@@ -164,7 +164,12 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
- if (dim != 3) {
+ if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) {
+ const size_t size0 = ggml_nbytes(src0);
+ const size_t size1 = ggml_nbytes(src1);
+ CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+ CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+ } else {
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda(
src0_d + i3 * (src0->nb[3] / 4),
@@ -173,13 +178,24 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
src0->ne[0], src0->ne[1], src0->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
}
- } else {
- const size_t size0 = ggml_nbytes(src0);
- const size_t size1 = ggml_nbytes(src1);
-
- CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
- CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
}
+
+ //if (dim != 3) {
+ // for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+ // concat_f32_cuda(
+ // src0_d + i3 * (src0->nb[3] / 4),
+ // src1_d + i3 * (src1->nb[3] / 4),
+ // dst_d + i3 * ( dst->nb[3] / 4),
+ // src0->ne[0], src0->ne[1], src0->ne[2],
+ // dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
+ // }
+ //} else {
+ // const size_t size0 = ggml_nbytes(src0);
+ // const size_t size1 = ggml_nbytes(src1);
+
+ // CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+ // CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+ //}
} else {
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 80dd25ff..91c0c5db 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -12627,6 +12627,26 @@ static void ggml_compute_forward_concat_f32(
GGML_ASSERT(dim >= 0 && dim < 4);
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst) &&
+ (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) {
+ // simply copy the data
+ const int64_t size_src_0 = ggml_nbytes(src0);
+ const int64_t size_src_1 = ggml_nbytes(src1);
+ const int64_t block_size = 4096;
+ const int64_t num_blocks = (size_src_0 + size_src_1 + block_size - 1)/block_size;
+ for (int64_t i_block = ith; i_block < num_blocks; i_block += nth) {
+ const int64_t start = i_block*block_size;
+ if (start < size_src_0) {
+ int64_t copy_size = MIN(block_size, size_src_0 - start);
+ memcpy((char *)dst->data + start, (char *)src0->data + start, copy_size);
+ } else {
+ int64_t copy_size = MIN(block_size, size_src_0 + size_src_1 - start);
+ memcpy((char *)dst->data + start, (char *)src1->data + start - size_src_0, copy_size);
+ }
+ }
+ return;
+ }
+
int64_t o[4] = {0, 0, 0, 0};
o[dim] = src0->ne[dim];
diff --git a/include/llama.h b/include/llama.h
index 2b33701c..bb43aebc 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -384,6 +384,7 @@ extern "C" {
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
int mla_attn; // whether to use MLA attention [EXPERIMENTAL]
+ int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL]
bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL]
// Abort callback
diff --git a/src/llama.cpp b/src/llama.cpp
index f2c5f9d4..0dcc78dc 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2511,6 +2511,7 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
int mla_attn;
+ int attn_max_batch;
bool fused_moe_up_gate;
enum llama_pooling_type pooling_type;
@@ -8774,61 +8775,108 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
} else {
- struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
- cb(kq, "kq", il);
- //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
- // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
- // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
- ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
- }
+ // split cached v into n_head heads
+ struct ggml_tensor * v =
+ ggml_view_3d(ctx, kv.v_l[il],
+ n_kv, n_embd_head_v, n_head_kv,
+ ggml_element_size(kv.v_l[il])*n_ctx,
+ ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
+ 0);
+ cb(v, "v", il);
- if (model.arch == LLM_ARCH_GROK) {
- // need to do the following:
- // multiply by attn_output_multiplyer of 0.08838834764831845
- // and then :
- // kq = 30 * tanh(kq / 30)
- // before the softmax below
+ auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
+ if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]) {
+ struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+ cb(kq, "kq", il);
- //try from phi2
//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
- //kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
- //kq = ggml_scale(ctx, kq, 30);
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+ // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
+ // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
+ ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+ }
- kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f);
- }
+ if (model.arch == LLM_ARCH_GROK) {
+ // need to do the following:
+ // multiply by attn_output_multiplyer of 0.08838834764831845
+ // and then :
+ // kq = 30 * tanh(kq / 30)
+ // before the softmax below
- if (hparams.attn_soft_cap) {
- //kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
- kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias,
- 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
- } else {
- kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
- }
- cb(kq, "kq_soft_max_ext", il);
+ //try from phi2
+ //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
- GGML_ASSERT(kv.size == n_ctx);
+ //kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
+ //kq = ggml_scale(ctx, kq, 30);
- // split cached v into n_head heads
- struct ggml_tensor * v =
- ggml_view_3d(ctx, kv.v_l[il],
- n_kv, n_embd_head_v, n_head_kv,
- ggml_element_size(kv.v_l[il])*n_ctx,
- ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
- 0);
- cb(v, "v", il);
+ kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f);
+ }
+
+ if (hparams.attn_soft_cap) {
+ //kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+ 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ } else {
+ kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ }
+ cb(kq, "kq_soft_max_ext", il);
+
+ GGML_ASSERT(kv.size == n_ctx);
- struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
- cb(kqv, "kqv", il);
+ struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
+ cb(kqv, "kqv", il);
- struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
- cb(kqv_merged, "kqv_merged", il);
+ struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
+ cb(kqv_merged, "kqv_merged", il);
- cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
- cb(cur, "kqv_merged_cont", il);
+ cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
+ cb(cur, "kqv_merged_cont", il);
+ }
+ else {
+ // For now we will not support this option if k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2];
+ GGML_ASSERT(k->ne[2] == v->ne[2] && k->ne[2] == q->ne[2]);
+ int n_step = (kq_size + cparams.attn_max_batch - 1)/cparams.attn_max_batch;
+ n_step = std::min(n_step, int(k->ne[2]));
+ int n_per_step = (q->ne[2] + n_step - 1)/n_step;
+ auto r2k = q->ne[2] / k->ne[2];
+ auto r2v = q->ne[2] / v->ne[2];
+ n_step = q->ne[2];
+ n_per_step = 1;
+ ggml_tensor * kqv;
+ for (int i12 = 0; i12 < q->ne[2]; i12 += n_per_step) {
+ int this_ne12 = i12 + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i12;
+ int i02 = i12/r2k;
+ auto k_i = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], this_ne12, k->nb[1], k->nb[2], k->nb[2]*i02);
+ auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12);
+ auto kq_i = ggml_mul_mat(ctx, k_i, q_i);
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+ ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32);
+ }
+ if (model.arch == LLM_ARCH_GROK) {
+ kq_i = ggml_softcap(ctx, kq_i, 0.08838834764831845f/30.0f, 30.f);
+ }
+ if (hparams.attn_soft_cap) {
+ kq_i = ggml_softcap_max(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+ 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+ } else {
+ kq_i = ggml_soft_max_ext(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ }
+ i02 = i12 / r2v;
+ auto v_i = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], this_ne12, v->nb[1], v->nb[2], v->nb[2]*i02);
+ auto kqv_i = ggml_mul_mat(ctx, v_i, kq_i);
+ if (i12 == 0) {
+ kqv = kqv_i;
+ } else {
+ kqv = ggml_concat(ctx, kqv, kqv_i, 2);
+ }
+ }
+ ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
+ cb(kqv_merged, "kqv_merged", il);
+ cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
+ cb(cur, "kqv_merged_cont", il);
+ }
}
ggml_build_forward_expand(graph, cur);
@@ -8924,6 +8972,7 @@ struct llm_build_context {
const bool flash_attn;
const int mla_attn;
+ const int attn_max_batch;
const bool fused_moe_up_gate;
const enum llama_pooling_type pooling_type;
@@ -8976,6 +9025,7 @@ struct llm_build_context {
n_ctx_orig (cparams.n_ctx_orig_yarn),
flash_attn (cparams.flash_attn),
mla_attn (cparams.mla_attn),
+ attn_max_batch (cparams.attn_max_batch),
fused_moe_up_gate(cparams.fused_moe_up_gate),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
@@ -13572,25 +13622,6 @@ struct llm_build_context {
ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
cb(q, "q", il);
- if (!pp_opt) {
- q = ggml_permute(ctx0, q, 0, 2, 1, 3);
- cb(q, "q_perm", il);
- }
- ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
- cb(kq, "kq", il);
-
- if (!pp_opt) {
- kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
- cb(kq, "kq_perm", il);
- }
-
- kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
- cb(kq, "kq_soft_max_ext", il);
-
- if (!pp_opt) {
- kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
- cb(kq, "kq_soft_max_ext_perm", il);
- }
if (lctx.cparams.mla_attn > 1) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
@@ -13602,12 +13633,60 @@ struct llm_build_context {
cb(kv_cache_trans, "kv_cache_trans", il);
}
- struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
- cb(kqv_compressed, "kqv_compressed", il);
+ ggml_tensor * kqv_compressed;
+ auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB
+ if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
+ if (!pp_opt) {
+ q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+ cb(q, "q_perm", il);
+ }
+
+ ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
+ cb(kq, "kq", il);
+
+ if (!pp_opt) {
+ kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
+ cb(kq, "kq_perm", il);
+ }
+
+ kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+ cb(kq, "kq_soft_max_ext", il);
+
+ if (!pp_opt) {
+ kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+ cb(kq, "kq_soft_max_ext_perm", il);
+ }
+
+ kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
+ cb(kqv_compressed, "kqv_compressed", il);
+
+ if (!pp_opt) {
+ kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+ cb(kqv_compressed, "kqv_compressed_perm", il);
+ }
+
+ } else {
+
+ int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
+ n_step = std::min(n_step, int(q->ne[2]));
+ int n_per_step = (q->ne[2] + n_step - 1)/n_step;
- if (!pp_opt) {
- kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
- cb(kqv_compressed, "kqv_compressed_perm", il);
+ //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step);
+
+ for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
+ int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
+ ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
+ ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
+ kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+ ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
+ if (i_head == 0) {
+ kqv_compressed = kqv_i;
+ } else {
+ kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
+ }
+ ggml_build_forward_expand(gf, kqv_compressed);
+ }
+ cb(kqv_compressed, "kqv_compressed", il);
}
struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
@@ -17644,6 +17723,7 @@ struct llama_context_params llama_context_default_params() {
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
/*.mla_attn =*/ 0,
+ /*.attn_max_batch =*/ 0,
/*.fused_moe_up_gate =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
@@ -17844,6 +17924,7 @@ struct llama_context * llama_new_context_with_model(
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
+ cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate= params.fused_moe_up_gate;
cparams.pooling_type = params.pooling_type;
@@ -17912,6 +17993,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
+ LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);