summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
Diffstat (limited to 'common')
-rw-r--r--common/common.cpp8
-rw-r--r--common/common.h1
2 files changed, 9 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 64f160af..8eb23ade 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1599,6 +1599,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.process_output = true;
return true;
}
+ if (arg == "--output-tensor-name") {
+ if (++i >= argc) {
+ invalid_param = true;
+ return true;
+ }
+ params.output_tensor_name = argv[i];
+ return true;
+ }
if (arg == "--no-ppl") {
params.compute_ppl = false;
return true;
diff --git a/common/common.h b/common/common.h
index 9a1dc4a2..bb45b3b4 100644
--- a/common/common.h
+++ b/common/common.h
@@ -224,6 +224,7 @@ struct gpt_params {
// imatrix params
std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
+ std::string output_tensor_name = "output.weight"; // name of the output tensor
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations