diff options
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r-- | examples/main/main.cpp | 83 |
1 files changed, 61 insertions, 22 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c096f110..5ea67051 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -439,6 +439,21 @@ int main(int argc, char ** argv) { LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + + // group-attention state + // number of grouped KV tokens so far (used only if params.grp_attn_n > 1) + int ga_i = 0; + + const int ga_n = params.grp_attn_n; + const int ga_w = params.grp_attn_w; + + if (ga_n != 1) { + GGML_ASSERT(ga_n > 0 && "grp_attn_n must be positive"); // NOLINT + GGML_ASSERT(ga_w % ga_n == 0 && "grp_attn_w must be a multiple of grp_attn_n"); // NOLINT + //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of grp_attn_w"); // NOLINT + //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT + LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w); + } LOG_TEE("\n\n"); if (params.interactive) { @@ -500,37 +515,61 @@ int main(int argc, char ** argv) { fflush(stdout); } - // infinite text generation via context swapping - // if we run out of context: - // - take the n_keep first tokens from the original prompt (via n_past) - // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches - if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) { - if (params.n_predict == -2) { - LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); - break; - } + if (ga_n == 1) { + // infinite text generation via context shifting + // if we run out of context: + // - take the n_keep first tokens from the original prompt (via n_past) + // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches + if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) { + if (params.n_predict == -2) { + LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); + break; + } - const int n_left = n_past - params.n_keep - 1; - const int n_discard = n_left/2; + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; - LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", - n_past, n_left, n_ctx, params.n_keep, n_discard); + LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); - n_past -= n_discard; + n_past -= n_discard; - if (ctx_guidance) { - n_past_guidance -= n_discard; + if (ctx_guidance) { + n_past_guidance -= n_discard; + } + + LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + + LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + + LOG("clear session path\n"); + path_session.clear(); } + } else { + // context extension via Self-Extend + while (n_past >= ga_i + ga_w) { + const int ib = (ga_n*ga_i)/ga_w; + const int bd = (ga_w/ga_n)*(ga_n - 1); + const int dd = (ga_w/ga_n) - ib*bd - ga_w; - LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + LOG("\n"); + LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd); + LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); + LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + llama_kv_cache_seq_shift(ctx, 0, ga_i, n_past, ib*bd); + llama_kv_cache_seq_div (ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); - LOG("clear session path\n"); - path_session.clear(); + n_past -= bd; + + ga_i += ga_w/ga_n; + + LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i); + } } // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past) |