author    saood06 <saood05@gmail.com>      2025-02-10 09:40:38 -0600
committer GitHub <noreply@github.com>      2025-02-10 17:40:38 +0200
commit    a366a3d17d8f2de0eb8c3d9eddc7b5840fb5761a (patch)
tree      ae0cb943fb4b83cb9e24d1a51d15550d5d7f0903
parent    c12f73ba6153d162f36434cb48e36dd3649b7701 (diff)
Load all MoE experts during warmup and make warmup 1 token (#198)
* Load all MoE experts during warmup

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>

* Unify warmup to one token

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
-rw-r--r--  common/common.cpp                     |  6
-rw-r--r--  examples/llama-bench/llama-bench.cpp  |  2
-rw-r--r--  src/llama.cpp                         | 19
3 files changed, 17 insertions(+), 10 deletions(-)
diff --git a/common/common.cpp b/common/common.cpp
index 6219f0ce..44678d7a 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2169,8 +2169,10 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         if (bos != -1) {
             tmp.push_back(bos);
         }
-        tmp.push_back(eos);
-
+        else
+        {
+            tmp.push_back(eos);
+        }
         if (llama_model_has_encoder(model)) {
             llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
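
This hunk changes the warmup prompt in llama_init_from_gpt_params from BOS+EOS to a single token: BOS when the vocabulary defines one, otherwise EOS. A minimal standalone sketch of the resulting selection logic; the helper is hypothetical, while llama_token_bos/llama_token_eos follow the llama.cpp C API of this vintage:

#include <vector>
#include "llama.h"

// Hypothetical helper restating the warmup-token choice above: exactly one
// token, preferring BOS and falling back to EOS for BOS-less vocabularies.
static std::vector<llama_token> make_warmup_tokens(const llama_model * model) {
    std::vector<llama_token> tmp;
    const llama_token bos = llama_token_bos(model); // -1 if the model has no BOS
    const llama_token eos = llama_token_eos(model);
    if (bos != -1) {
        tmp.push_back(bos);
    } else {
        tmp.push_back(eos);
    }
    return tmp;
}

A one-token batch matters here because the graph builder in src/llama.cpp (below) recognizes warmup by seeing exactly one token that equals BOS.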
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 41b93df5..95df06dc 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -1586,7 +1586,7 @@ int main(int argc, char ** argv) {
         if (params.warmup) {
             if (t.n_prompt > 0) {
                 //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-                test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
+                test_prompt(ctx, 1, 0, t.n_batch, t.n_threads);
             }
             if (t.n_gen > 0) {
                 test_gen(ctx, 1, 0, t.n_threads);
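
The llama-bench warmup previously ran the full prompt length; it now submits a single prompt token, matching the unified one-token warmup. Since the warmup's job after this change is to fault expert weights into memory rather than to measure anything, and the builder routes that one token through every expert, a single token plausibly suffices.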
diff --git a/src/llama.cpp b/src/llama.cpp
index 00e6c934..b2553802 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3784,7 +3784,7 @@ static size_t llama_model_max_nodes(const llama_model & /*model*/) {
     //    return 32768;
     //}
-    return 8192;
+    return 65536;
 }

 struct llama_model_loader {
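
The graph-node budget grows from 8192 to 65536, presumably because routing a token through all experts instead of the usual top-k adds graph nodes for every expert in every MoE layer; large MoE models (the commented-out branch above already special-cased 32768) could otherwise overflow the allocator's node limit during the warmup pass.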
@@ -8879,7 +8879,8 @@ struct llm_build_context {
         llama_context & lctx,
         const llama_batch & batch,
         const llm_build_cb & cb,
-        bool worst_case) :
+        bool worst_case,
+        bool warmup) :
         model            (lctx.model),
         lctx             (lctx),
         hparams          (model.hparams),
@@ -8897,7 +8898,7 @@ struct llm_build_context {
         n_embd_head_v    (hparams.n_embd_head_v),
         n_embd_v_gqa     (hparams.n_embd_v_gqa()),
         n_expert         (hparams.n_expert),
-        n_expert_used    (hparams.n_expert_used),
+        n_expert_used    (warmup ? hparams.n_expert : hparams.n_expert_used),
         freq_base        (cparams.rope_freq_base),
         freq_scale       (cparams.rope_freq_scale),
         ext_factor       (cparams.yarn_ext_factor),
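
This is the core of the change: during warmup, n_expert_used is widened to n_expert, so the MoE router's top-k selection degenerates to "all experts" and every expert's tensors are read, and therefore loaded or uploaded, on the first pass. A standalone illustration of that routing effect, not llama.cpp code (the expert count, router scores, and helper are all made up for the example):

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

// Return the indices of the k highest-scoring experts for one token.
static std::vector<int> top_k_experts(const std::vector<float> & scores, int k) {
    std::vector<int> idx(scores.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return scores[a] > scores[b]; });
    idx.resize(k);
    return idx;
}

int main() {
    const int  n_expert      = 8;     // experts in the layer
    const int  n_expert_used = 2;     // normal top-k routing
    const bool warmup        = true;  // as detected by the graph builder

    const std::vector<float> router_logits = {0.1f, 0.9f, 0.3f, 0.2f, 0.8f, 0.05f, 0.4f, 0.6f};

    // During warmup k == n_expert, so every expert participates and its
    // weights are touched; otherwise only the usual top-k run.
    const int k = warmup ? n_expert : n_expert_used;
    for (int e : top_k_experts(router_logits, k)) {
        printf("expert %d participates\n", e);
    }
    return 0;
}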
@@ -14433,7 +14434,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };

-    struct llm_build_context llm(lctx, dummy, cb, false);
+    struct llm_build_context llm(lctx, dummy, cb, false, false);

     llm.init();
@@ -14450,7 +14451,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };

-    struct llm_build_context llm(lctx, dummy, cb, false);
+    struct llm_build_context llm(lctx, dummy, cb, false, false);

     llm.init();
@@ -14467,7 +14468,7 @@ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };

-    struct llm_build_context llm(lctx, dummy, cb, false);
+    struct llm_build_context llm(lctx, dummy, cb, false, false);

     llm.init();
@@ -14517,7 +14518,11 @@ static struct ggml_cgraph * llama_build_graph(
     struct ggml_cgraph * result = NULL;

-    struct llm_build_context llm(lctx, batch, cb, worst_case);
+    const llama_vocab * vocab = llama_get_vocab(&lctx);
+    llama_token bos = llama_token_bos_impl(*vocab);
+    llama_token eos = llama_token_eos_impl(*vocab);
+    bool is_warming_up = (batch.n_tokens == 1 && batch.token[0] == bos);
+    struct llm_build_context llm(lctx, batch, cb, worst_case, is_warming_up);

     llm.init();
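
Finally, the builder infers warmup from the batch itself: exactly one token, and that token is BOS, which is precisely the batch common.cpp now constructs (eos is fetched alongside but the check keys on bos). A hedged standalone restatement of the heuristic, with a hypothetical helper name and llama_batch fields as in the public header:

#include "llama.h"

// A batch counts as a warmup pass when it carries exactly one token and that
// token is the BOS token.
static bool is_warmup_batch(const llama_batch & batch, llama_token bos) {
    return batch.n_tokens == 1 && batch.token != nullptr && batch.token[0] == bos;
}

One consequence worth noting: a genuine single-token decode whose token happens to be BOS would also be treated as warmup and routed through all experts.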