diff options
-rw-r--r-- | common/common.cpp | 8 | ||||
-rw-r--r-- | common/common.h | 1 | ||||
-rw-r--r-- | examples/imatrix/imatrix.cpp | 3 |
3 files changed, 11 insertions, 1 deletions
diff --git a/common/common.cpp b/common/common.cpp index 64f160af..8eb23ade 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1599,6 +1599,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.process_output = true; return true; } + if (arg == "--output-tensor-name") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.output_tensor_name = argv[i]; + return true; + } if (arg == "--no-ppl") { params.compute_ppl = false; return true; diff --git a/common/common.h b/common/common.h index 9a1dc4a2..bb45b3b4 100644 --- a/common/common.h +++ b/common/common.h @@ -224,6 +224,7 @@ struct gpt_params { // imatrix params std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file + std::string output_tensor_name = "output.weight"; // name of the output tensor int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 574f5ed9..c7d73cdb 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -83,7 +83,8 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * if (t->op != GGML_OP_MUL_MAT) return false; // why are small batches ignored (<16 tokens)? if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false; - if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == "output.weight"))) return false; + //printf("wname = %s\n", wname.c_str()); + if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == m_params.output_tensor_name))) return false; return true; } |