diff options
author | pudepiedj <pudepiedj@gmail.com> | 2024-01-11 16:14:52 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-11 18:14:52 +0200 |
commit | 43f76bf1c362c067fce46bb8dcda0b64af8a9533 (patch) | |
tree | c446cabb97601363b645f2fb661bc48d790478fa /common | |
parent | 2f043328e3116724d15b915b5c6078e2df860a69 (diff) |
main : print total token count and tokens consumed so far (#4874)
* Token count changes
* Add show token count
* Updating before PR
* Two requested changes
* Move param def posn
Diffstat (limited to 'common')
-rw-r--r-- | common/common.cpp | 8 | ||||
-rw-r--r-- | common/common.h | 2 |
2 files changed, 9 insertions, 1 deletions
diff --git a/common/common.cpp b/common/common.cpp index 4e89fe51..bfcd6d4d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -630,6 +630,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.ppl_stride = std::stoi(argv[i]); + } else if (arg == "-stc" || arg == "--show_token_count") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.token_interval = std::stoi(argv[i]); } else if (arg == "--ppl-output-type") { if (++i >= argc) { invalid_param = true; @@ -944,6 +950,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); + printf(" -stc N --show_token_count N\n"); + printf(" show consumed tokens every N tokens\n"); printf("\n"); #ifndef LOG_DISABLE_LOGS log_print_usage(); diff --git a/common/common.h b/common/common.h index e2bbfc25..a295e88b 100644 --- a/common/common.h +++ b/common/common.h @@ -64,6 +64,7 @@ struct gpt_params { int32_t n_beams = 0; // if non-zero then use beam search of given width. int32_t grp_attn_n = 1; // group-attention factor int32_t grp_attn_w = 512; // group-attention width + int32_t token_interval = 512; // show token count every 512 tokens float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor @@ -242,4 +243,3 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); - |