-rw-r--r--  common/common.cpp                       7
-rw-r--r--  common/common.h                         2
-rw-r--r--  examples/llama-bench/llama-bench.cpp   16
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cu             26
-rw-r--r--  include/llama.h                         2
-rw-r--r--  src/llama.cpp                         102
6 files changed, 64 insertions(+), 91 deletions(-)
diff --git a/common/common.cpp b/common/common.cpp
index 464b4710..6359426f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -851,7 +851,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-mla" || arg == "--mla-use") {
- params.mla_attn = true;
+ CHECK_ARG
+ params.mla_attn = std::stoi(argv[i]);
return true;
}
if (arg == "-fmoe" || arg == "--fused-moe") {
@@ -1514,7 +1515,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
- options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" });
+ options.push_back({ "*", "-mla, --mla-use N", "use MLA attention (0: off, 1: K + transposed V cache, 2: K cache only; default: %d)", params.mla_attn });
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
@@ -3357,7 +3358,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
- fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false");
+ fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
diff --git a/common/common.h b/common/common.h
index 152fd1cf..ef5175f3 100644
--- a/common/common.h
+++ b/common/common.h
@@ -175,7 +175,7 @@ struct gpt_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
- bool mla_attn = false; // MLA
+ int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
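
With this change, -mla is no longer an on/off switch but an integer mode. A minimal sketch of the intended semantics (the enum and helper names below are hypothetical, introduced only for illustration; the patch itself keeps a plain int):

    // Hypothetical names, for illustration only.
    enum mla_mode : int {
        MLA_OFF        = 0, // standard attention
        MLA_KV_AND_KVT = 1, // MLA with K cache and a transposed c^KV cache
        MLA_KV_ONLY    = 2, // MLA with just the K cache; the transpose is built on the fly
    };

    // Only mode 1 allocates the extra kvt_l tensors (see the llama_kv_cache_init hunk below).
    inline bool needs_transposed_cache(int mla_attn) {
        return mla_attn == 1;
    }

Per the comment this change removes from llama.cpp, running without the transposed cache roughly halves the MLA cache size at the cost of slower token generation, while prompt processing stays about the same.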
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 438d2a7c..5756843a 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -232,7 +232,7 @@ struct cmd_params {
std::vector<int> main_gpu;
std::vector<bool> no_kv_offload;
std::vector<bool> flash_attn;
- std::vector<bool> mla_attn;
+ std::vector<int> mla_attn;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -264,7 +264,7 @@ static const cmd_params cmd_params_defaults = {
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* flash_attn */ {false},
- /* mla_attn */ {false},
+ /* mla_attn */ {0},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -300,7 +300,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
- printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
+ printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -576,7 +576,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
invalid_param = true;
break;
}
- auto p = string_split<bool>(argv[i], split_delim);
+ auto p = string_split<int>(argv[i], split_delim);
params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
@@ -726,7 +726,7 @@ struct cmd_params_instance {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
- bool mla_attn;
+ int mla_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -955,7 +955,7 @@ struct test {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
- bool mla_attn;
+ int mla_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -1097,13 +1097,13 @@ struct test {
field == "n_threads" ||
field == "model_size" || field == "model_n_params" ||
field == "n_gpu_layers" || field == "main_gpu" ||
- field == "n_prompt" || field == "n_gen" ||
+ field == "n_prompt" || field == "n_gen" || field == "mla_attn" ||
field == "avg_ns" || field == "stddev_ns") {
return INT;
}
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
- field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" ||
+ field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" ||
field == "fused_moe") {
return BOOL;
}
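
Because llama-bench now treats -mla as an integer column (moved from the BOOL to the INT field set), a comma-separated list such as 0,1,2 sweeps all three modes in one run. A self-contained sketch of that parsing step, assuming nothing from the repo's own string_split helper:

    #include <sstream>
    #include <string>
    #include <vector>

    // Turn "-mla 0,1,2" style input into a list of integer MLA modes.
    static std::vector<int> split_modes(const std::string & s, char delim = ',') {
        std::vector<int> out;
        std::stringstream ss(s);
        std::string item;
        while (std::getline(ss, item, delim)) {
            out.push_back(std::stoi(item)); // each entry is a full mode value: 0, 1 or 2
        }
        return out;
    }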
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 2a520d68..e17e77a3 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -158,11 +158,6 @@ static void mul_mat_vec_q_cuda(
int64_t nwarps = 1;
int64_t rows_per_cuda_block = 1;
- //if (ne2 > 1) {
- // printf("%s: ncols_x = %d, nrows_x = %d, nrows_y = %d, ncols_y = %d nrows_dst = %d, ne2 = %d nb02 = %zu, nb12 = %zu, nb2 = %zu\n", __func__,
- // ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2);
- //}
-
if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
switch(ncols_y) {
case 1:
@@ -382,9 +377,8 @@ static void mul_mat_vec_iq3_s_q8_1_cuda(
mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream);
}
-namespace {
-void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type,
- const int64_t ne00, const int64_t ne10, const int64_t ne0, const int64_t ne2,
+static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type,
+ const int64_t ne00, const int64_t ne0, const int64_t ne2,
const int64_t nb02, const int64_t nb12, const int64_t nb2,
const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i,
const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
@@ -496,7 +490,6 @@ void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type
}
}
-}
void ggml_cuda_op_mul_mat_vec_q_3D(
ggml_backend_cuda_context & ctx,
@@ -505,8 +498,6 @@ void ggml_cuda_op_mul_mat_vec_q_3D(
const int64_t src1_padded_row_size, cudaStream_t stream) {
const int64_t ne00 = src0->ne[0];
- const int64_t row_diff = row_high - row_low;
-
const int64_t ne10 = src1->ne[0];
GGML_ASSERT(ne10 % QK8_1 == 0);
GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1);
@@ -516,13 +507,10 @@ void ggml_cuda_op_mul_mat_vec_q_3D(
int id = ggml_cuda_get_device();
- // the main device has a larger memory buffer to hold the results from all GPUs
- // nrows_dst == nrows of the matrix that the kernel writes into
- const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size);
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
- ne00, ne10, ne0, dst->ne[2],
+ ne00, ne0, dst->ne[2],
src0->nb[2], src1_row_size, dst->nb[2],
src0_dd_i, src1_ddq_i, dst_dd_i,
row_low, row_high, src1_ncols,
@@ -538,8 +526,6 @@ void ggml_cuda_op_mul_mat_vec_q(
const int64_t src1_padded_row_size, cudaStream_t stream) {
const int64_t ne00 = src0->ne[0];
- const int64_t row_diff = row_high - row_low;
-
const int64_t ne10 = src1->ne[0];
GGML_ASSERT(ne10 % QK8_1 == 0);
@@ -547,12 +533,8 @@ void ggml_cuda_op_mul_mat_vec_q(
int id = ggml_cuda_get_device();
- // the main device has a larger memory buffer to hold the results from all GPUs
- // nrows_dst == nrows of the matrix that the kernel writes into
- const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
-
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
- ne00, ne10, ne0, 1, 0, 0, 0,
+ ne00, ne0, 1, 0, 0, 0,
src0_dd_i, src1_ddq_i, dst_dd_i,
row_low, row_high, src1_ncols,
src1_padded_row_size, stream);
diff --git a/include/llama.h b/include/llama.h
index beb6ecba..2b33701c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -383,7 +383,7 @@ extern "C" {
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
- bool mla_attn; // whether to use MLA attention [EXPERIMENTAL]
+ int mla_attn; // whether to use MLA attention [EXPERIMENTAL]
bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL]
// Abort callback
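
On the API side the field simply becomes an int in llama_context_params. A minimal usage sketch (model loading and error handling omitted; the chosen mode is just an example):

    #include "llama.h"

    llama_context * make_mla_context(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.mla_attn = 2; // 0: standard attention, 1: MLA + transposed cache, 2: MLA, K cache only
        return llama_new_context_with_model(model, cparams);
    }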
diff --git a/src/llama.cpp b/src/llama.cpp
index ebc7a772..f2c5f9d4 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -112,14 +112,6 @@
#define LLAMA_MAX_EXPERTS 256 // DeepSeekV2
//
-// === MLA cache
-// If tou are desperate to reduce KV cache size, set MLA_USE_TRANSPOSED_CACHE to 0.
-// TG perfornce will be slower (similar to no-MLA), but KV cache size will be cut to ~half.
-// PP performance will be about the same as with MLA_USE_TRANSPOSED_CACHE = 1.
-//
-#define MLA_USE_TRANSPOSED_CACHE 1
-
-//
// helpers
//
@@ -2518,7 +2510,7 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
- bool mla_attn;
+ int mla_attn;
bool fused_moe_up_gate;
enum llama_pooling_type pooling_type;
@@ -2695,9 +2687,7 @@ struct llama_kv_cache {
// DeepSeek MLA
std::vector<struct ggml_tensor *> kv_l;
-#if MLA_USE_TRANSPOSED_CACHE
std::vector<struct ggml_tensor *> kvt_l;
-#endif
std::vector<struct ggml_context *> ctxs;
std::vector<ggml_backend_buffer_t> bufs;
@@ -3175,9 +3165,9 @@ static bool llama_kv_cache_init(
// DeepSeek MLA
cache.kv_l.reserve(n_layer);
-#if MLA_USE_TRANSPOSED_CACHE
- cache.kvt_l.reserve(n_layer);
-#endif
+ if (cparams.mla_attn == 1) {
+ cache.kvt_l.reserve(n_layer);
+ }
bool warn = true;
int n_mla = 0;
@@ -3208,25 +3198,18 @@ static bool llama_kv_cache_init(
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
-#if MLA_USE_TRANSPOSED_CACHE
- ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
- //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size);
-#else
- ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_v, kv_lora_rank + n_embd_head_qk_rope, kv_size);
-#endif
+ auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v;
+ ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size);
ggml_format_name(kv, "cache_kv_l%d", i);
cache.kv_l.push_back(kv);
-#if MLA_USE_TRANSPOSED_CACHE
- ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size);
- ggml_format_name(kvt, "cache_kvt_l%d", i);
- cache.kvt_l.push_back(kvt);
-#endif
+ if (cparams.mla_attn == 1) {
+ ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size);
+ ggml_format_name(kvt, "cache_kvt_l%d", i);
+ cache.kvt_l.push_back(kvt);
+ }
n_mla++;
}
else {
- //printf("Creating cache tensors:\n");
- //printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k);
- //k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
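
The hunk above is where the size difference between the two MLA modes comes from: mode 1 allocates both the (kv_lora_rank + n_embd_head_qk_rope) x kv_size tensor and a kv_lora_rank x kv_size transposed tensor, while mode 2 allocates only the former. A back-of-the-envelope sketch, assuming f16 cache cells and DeepSeek-style dimensions (kv_lora_rank = 512, n_embd_head_qk_rope = 64):

    #include <cstdio>

    int main() {
        const int kv_lora_rank        = 512; // example value
        const int n_embd_head_qk_rope = 64;  // example value
        const int bytes_per_cell      = 2;   // f16; quantized cache types shrink this further

        const int mode1 = (kv_lora_rank + n_embd_head_qk_rope + kv_lora_rank) * bytes_per_cell; // c^KV+rope plus kv^T
        const int mode2 = (kv_lora_rank + n_embd_head_qk_rope) * bytes_per_cell;                // c^KV+rope only

        printf("per token, per layer: mla_attn=1 -> %d bytes, mla_attn=2 -> %d bytes\n", mode1, mode2);
        return 0;
    }

In this example that is 2176 vs 1152 bytes per token per layer, which is where the "cut to ~half" figure in the removed comment came from.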
@@ -8940,7 +8923,7 @@ struct llm_build_context {
const int32_t n_ctx_orig;
const bool flash_attn;
- const bool mla_attn;
+ const int mla_attn;
const bool fused_moe_up_gate;
const enum llama_pooling_type pooling_type;
@@ -13546,20 +13529,22 @@ struct llm_build_context {
if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) {
-#if MLA_USE_TRANSPOSED_CACHE
- ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank,
- ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head));
- cb(kv_cache_trans_view, "kv_cache_trans_view", il);
+ ggml_tensor * kv_cache_trans;
- // note: storing transposed c^KV in the transposed KV cache
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
+ if (lctx.cparams.mla_attn == 1) {
+ ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank,
+ ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head));
+ cb(kv_cache_trans_view, "kv_cache_trans_view", il);
- ggml_tensor * kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il],
- n_kv, kv_lora_rank,
- ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size),
- 0);
- cb(kv_cache_trans, "kv_cache_trans", il);
-#endif
+ // note: storing transposed c^KV in the transposed KV cache
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
+
+ kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il],
+ n_kv, kv_lora_rank,
+ ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size),
+ 0);
+ cb(kv_cache_trans, "kv_cache_trans", il);
+ }
ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0);
cb(kvr, "kvr", il);
@@ -13607,15 +13592,15 @@ struct llm_build_context {
cb(kq, "kq_soft_max_ext_perm", il);
}
-#if !MLA_USE_TRANSPOSED_CACHE
- ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
- kv_lora_rank, n_kv,
- ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
- cb(kv_cache, "kv_cache_lora", il);
+ if (lctx.cparams.mla_attn > 1) {
+ ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
+ kv_lora_rank, n_kv,
+ ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+ cb(kv_cache_lora, "kv_cache_lora", il);
- ggml_tensor * kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
- cb(kv_cache_trans, "kv_cache_trans", il);
-#endif
+ kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
+ cb(kv_cache_trans, "kv_cache_trans", il);
+ }
struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
cb(kqv_compressed, "kqv_compressed", il);
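
The two branches above can be read as one decision: mode 1 just views the precomputed kvt_l cache, while mode 2 builds the same n_kv x kv_lora_rank operand by transposing the c^KV part of kv_l on the fly. A condensed, illustrative sketch (the function and parameter names are hypothetical; cache_kv / cache_kvt stand in for kv_self.kv_l[il] / kv_self.kvt_l[il]):

    #include "ggml.h"

    static ggml_tensor * mla_kv_cache_trans(
            ggml_context * ctx0, ggml_tensor * cache_kv, ggml_tensor * cache_kvt,
            int mla_attn, int64_t n_kv, int64_t kv_size,
            int64_t kv_lora_rank, int64_t n_embd_head_qk_rope) {
        if (mla_attn == 1) {
            // mode 1: rows were already stored transposed in the kvt cache
            return ggml_view_2d(ctx0, cache_kvt, n_kv, kv_lora_rank,
                                ggml_row_size(cache_kvt->type, kv_size), 0);
        }
        // mode 2: take the c^KV part of the kv cache and transpose it on the fly
        ggml_tensor * lora = ggml_view_2d(ctx0, cache_kv, kv_lora_rank, n_kv,
                                          ggml_row_size(cache_kv->type, kv_lora_rank + n_embd_head_qk_rope), 0);
        return ggml_cont(ctx0, ggml_transpose(ctx0, lora));
    }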
@@ -17658,7 +17643,7 @@ struct llama_context_params llama_context_default_params() {
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
- /*.mla_attn =*/ false,
+ /*.mla_attn =*/ 0,
/*.fused_moe_up_gate =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
@@ -18140,18 +18125,23 @@ struct llama_context * llama_new_context_with_model(
kv_type = kv->type;
}
-#if MLA_USE_TRANSPOSED_CACHE
for (auto & kvt : ctx->kv_self.kvt_l) {
memory_size_kvt += ggml_nbytes(kvt);
kvt_type = kvt->type;
}
-#endif
if (memory_size_kv + memory_size_kvt > 0) {
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
- (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
- ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f),
- ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f));
+ if (cparams.mla_attn == 1) {
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
+ (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
+ ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f),
+ ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f));
+ } else {
+ GGML_ASSERT(memory_size_kvt == 0);
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T: not used\n", __func__,
+ (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
+ ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f));
+ }
}
}