From a366a3d17d8f2de0eb8c3d9eddc7b5840fb5761a Mon Sep 17 00:00:00 2001 From: saood06 Date: Mon, 10 Feb 2025 09:40:38 -0600 Subject: Load all MoE experts during warmup and make warmup 1 token (#198) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Load all MoE experts during warmup Co-authored-by: Stanisław Szymczyk * Unify warmup to one token --------- Co-authored-by: Stanisław Szymczyk --- src/llama.cpp | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) (limited to 'src/llama.cpp') diff --git a/src/llama.cpp b/src/llama.cpp index 00e6c934..b2553802 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3784,7 +3784,7 @@ static size_t llama_model_max_nodes(const llama_model & /*model*/) { // return 32768; //} - return 8192; + return 65536; } struct llama_model_loader { @@ -8879,7 +8879,8 @@ struct llm_build_context { llama_context & lctx, const llama_batch & batch, const llm_build_cb & cb, - bool worst_case) : + bool worst_case, + bool warmup) : model (lctx.model), lctx (lctx), hparams (model.hparams), @@ -8897,7 +8898,7 @@ struct llm_build_context { n_embd_head_v (hparams.n_embd_head_v), n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), - n_expert_used (hparams.n_expert_used), + n_expert_used (warmup ? hparams.n_expert : hparams.n_expert_used), freq_base (cparams.rope_freq_base), freq_scale (cparams.rope_freq_scale), ext_factor (cparams.yarn_ext_factor), @@ -14433,7 +14434,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - struct llm_build_context llm(lctx, dummy, cb, false); + struct llm_build_context llm(lctx, dummy, cb, false, false); llm.init(); @@ -14450,7 +14451,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - struct llm_build_context llm(lctx, dummy, cb, false); + struct llm_build_context llm(lctx, dummy, cb, false, false); llm.init(); @@ -14467,7 +14468,7 @@ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - struct llm_build_context llm(lctx, dummy, cb, false); + struct llm_build_context llm(lctx, dummy, cb, false, false); llm.init(); @@ -14517,7 +14518,11 @@ static struct ggml_cgraph * llama_build_graph( struct ggml_cgraph * result = NULL; - struct llm_build_context llm(lctx, batch, cb, worst_case); + const llama_vocab * vocab = llama_get_vocab(&lctx); + llama_token bos = llama_token_bos_impl(*vocab); + llama_token eos = llama_token_eos_impl(*vocab); + bool is_warming_up = (batch.n_tokens == 1 && batch.token[0] == bos); + struct llm_build_context llm(lctx, batch, cb, worst_case, is_warming_up); llm.init(); -- cgit v1.2.3