From 43f76bf1c362c067fce46bb8dcda0b64af8a9533 Mon Sep 17 00:00:00 2001 From: pudepiedj Date: Thu, 11 Jan 2024 16:14:52 +0000 Subject: main : print total token count and tokens consumed so far (#4874) * Token count changes * Add show token count * Updating before PR * Two requested changes * Move param def posn --- common/common.cpp | 8 ++++++++ common/common.h | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) (limited to 'common') diff --git a/common/common.cpp b/common/common.cpp index 4e89fe51..bfcd6d4d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -630,6 +630,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.ppl_stride = std::stoi(argv[i]); + } else if (arg == "-stc" || arg == "--show_token_count") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.token_interval = std::stoi(argv[i]); } else if (arg == "--ppl-output-type") { if (++i >= argc) { invalid_param = true; @@ -944,6 +950,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); + printf(" -stc N --show_token_count N\n"); + printf(" show consumed tokens every N tokens\n"); printf("\n"); #ifndef LOG_DISABLE_LOGS log_print_usage(); diff --git a/common/common.h b/common/common.h index e2bbfc25..a295e88b 100644 --- a/common/common.h +++ b/common/common.h @@ -64,6 +64,7 @@ struct gpt_params { int32_t n_beams = 0; // if non-zero then use beam search of given width. int32_t grp_attn_n = 1; // group-attention factor int32_t grp_attn_w = 512; // group-attention width + int32_t token_interval = 512; // show token count every 512 tokens float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor @@ -242,4 +243,3 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); - -- cgit v1.2.3