summaryrefslogtreecommitdiff
path: root/examples/imatrix/imatrix.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/imatrix/imatrix.cpp')
-rw-r--r--examples/imatrix/imatrix.cpp77
1 files changed, 44 insertions, 33 deletions
diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index 98c0e93e..71e7a727 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -23,6 +23,7 @@ struct Stats {
};
struct StatParams {
+ std::string dataset;
std::string ofile = "imatrix.dat";
int n_output_frequency = 10;
int verbosity = 1;
@@ -46,7 +47,7 @@ private:
std::vector<float> m_src1_data;
std::vector<char> m_ids; // the expert ids from ggml_mul_mat_id
//
- void save_imatrix(const char * file_name) const;
+ void save_imatrix(const char * file_name, const char * dataset) const;
void keep_imatrix(int ncall) const;
};
@@ -199,7 +200,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
}
void IMatrixCollector::save_imatrix() const {
- save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str());
+ save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str(), m_params.dataset.c_str());
}
void IMatrixCollector::keep_imatrix(int ncall) const {
@@ -207,24 +208,33 @@ void IMatrixCollector::keep_imatrix(int ncall) const {
if (file_name.empty()) file_name = "imatrix.dat";
file_name += ".at_";
file_name += std::to_string(ncall);
- save_imatrix(file_name.c_str());
+ save_imatrix(file_name.c_str(), m_params.dataset.c_str());
}
-void IMatrixCollector::save_imatrix(const char * fname) const {
+void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) const {
std::ofstream out(fname, std::ios::binary);
int n_entries = m_stats.size();
- out.write((const char*)&n_entries, sizeof(n_entries));
- for (auto& p : m_stats) {
+ out.write((const char *) &n_entries, sizeof(n_entries));
+ for (const auto & p : m_stats) {
int len = p.first.size();
- out.write((const char*)&len, sizeof(len));
+ out.write((const char *) &len, sizeof(len));
out.write(p.first.c_str(), len);
- out.write((const char*)&p.second.ncall, sizeof(p.second.ncall));
+ out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
int nval = p.second.values.size();
- out.write((const char*)&nval, sizeof(nval));
- if (nval > 0) out.write((const char*)p.second.values.data(), nval*sizeof(float));
+ out.write((const char *) &nval, sizeof(nval));
+ if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float));
}
+
+ // Write the number of call the matrix was computed with
+ out.write((const char *) &m_last_call, sizeof(m_last_call));
+
+ // Write the dataset name at the end of the file to later on specify it in quantize
+ int n_dataset = strlen(dataset);
+ out.write((const char *) &n_dataset, sizeof(n_dataset));
+ out.write(dataset, n_dataset);
+
if (m_params.verbosity > 0) {
- fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n",__func__,m_last_call,fname);
+ fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname);
}
}
@@ -547,6 +557,29 @@ int main(int argc, char ** argv) {
}
}
+ gpt_params params;
+ params.n_batch = 512;
+ if (!gpt_params_parse(args.size(), args.data(), params)) {
+ return 1;
+ }
+
+ params.logits_all = true;
+ params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+ print_build_info();
+
+ if (params.seed == LLAMA_DEFAULT_SEED) {
+ params.seed = time(NULL);
+ }
+
+ fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
+
+ std::mt19937 rng(params.seed);
+ if (params.random_prompt) {
+ params.prompt = gpt_random_prompt(rng);
+ }
+
+ sparams.dataset = params.prompt_file;
g_collector.set_parameters(std::move(sparams));
if (!combine_files.empty()) {
@@ -585,28 +618,6 @@ int main(int argc, char ** argv) {
}
}
- gpt_params params;
- params.n_batch = 512;
- if (!gpt_params_parse(args.size(), args.data(), params)) {
- return 1;
- }
-
- params.logits_all = true;
- params.n_batch = std::min(params.n_batch, params.n_ctx);
-
- print_build_info();
-
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
-
- fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
-
- std::mt19937 rng(params.seed);
- if (params.random_prompt) {
- params.prompt = gpt_random_prompt(rng);
- }
-
llama_backend_init();
llama_numa_init(params.numa);