diff options
author | slaren <slarengh@gmail.com> | 2024-04-26 18:39:58 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-26 18:39:58 +0200 |
commit | 017e6999b5184234370b22a2f868e1be911e8d88 (patch) | |
tree | 2a29b4d5bf7cfc6965ce895abee9e889b6529ade /common | |
parent | e2764cd7ca1112d9303eba9e81c9935ee67352ff (diff) |
add basic tensor data validation function (#6884)
* add basic tensor data validation function
* add --check-tensors command line argument
tensor validation is disabled by default and can be enabled by adding
`--check-tensors` to the command line arguments.
quantize always validates tensors.
Diffstat (limited to 'common')
-rw-r--r-- | common/common.cpp | 6 | ||||
-rw-r--r-- | common/common.h | 1 |
2 files changed, 7 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp index 97f55b05..9ab8aef7 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1089,6 +1089,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.n_print = std::stoi(argv[i]); return true; } + if (arg == "--check-tensors") { + params.check_tensors = true; + return true; + } if (arg == "--ppl-output-type") { if (++i >= argc) { invalid_param = true; @@ -1554,6 +1558,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); printf(" -ptc N, --print-token-count N\n"); printf(" print token count every N tokens (default: %d)\n", params.n_print); + printf(" --check-tensors check model tensor data for invalid values\n"); printf("\n"); #ifndef LOG_DISABLE_LOGS log_print_usage(); @@ -1774,6 +1779,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & mparams.tensor_split = params.tensor_split; mparams.use_mmap = params.use_mmap; mparams.use_mlock = params.use_mlock; + mparams.check_tensors = params.check_tensors; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; } else { diff --git a/common/common.h b/common/common.h index 87361e8e..0005f143 100644 --- a/common/common.h +++ b/common/common.h @@ -161,6 +161,7 @@ struct gpt_params { bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes bool no_kv_offload = false; // disable KV offloading bool warmup = true; // warmup run + bool check_tensors = false; // validate tensor data std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V |