summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-06-09 20:19:35 +0300
committerGitHub <noreply@github.com>2024-06-09 20:19:35 +0300
commite95beeb1fc4621826ddd616776dbdf717366bf5c (patch)
tree38dc664d1ebfb2459dafbc6d934a7cc678379911
parent57bf62ce7cb75cca589943e2050d29bff4026e76 (diff)
imatrix : handle partial entries (#7833)
-rw-r--r--examples/imatrix/imatrix.cpp58
1 files changed, 51 insertions, 7 deletions
diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index e18f4956..574f5ed9 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -218,20 +218,64 @@ void IMatrixCollector::save_imatrix(int ncall) const {
fname += std::to_string(ncall);
}
+ // avoid writing imatrix entries that do not have full data
+ // this can happen with MoE models where some of the experts end up not being exercised by the provided training data
+
+ int n_entries = 0;
+ std::vector<std::string> to_store;
+
+ bool is_first = true; // for printing
+ for (const auto & kv : m_stats) {
+ const int n_all = kv.second.counts.size();
+
+ if (n_all == 0) {
+ continue;
+ }
+
+ int n_zeros = 0;
+ for (const int c : kv.second.counts) {
+ if (c == 0) {
+ n_zeros++;
+ }
+ }
+
+ if (n_zeros != 0 && is_first) {
+ fprintf(stderr, "\n");
+ is_first = false;
+ }
+
+ if (n_zeros == n_all) {
+ fprintf(stderr, "%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str());
+ continue;
+ }
+
+ if (n_zeros > 0) {
+ fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
+ continue;
+ }
+
+ n_entries++;
+ to_store.push_back(kv.first);
+ }
+
+ if (to_store.size() < m_stats.size()) {
+ fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
+ }
+
std::ofstream out(fname, std::ios::binary);
- int n_entries = m_stats.size();
out.write((const char *) &n_entries, sizeof(n_entries));
- for (const auto & p : m_stats) {
- int len = p.first.size();
+ for (const auto & name : to_store) {
+ const auto & stat = m_stats.at(name);
+ int len = name.size();
out.write((const char *) &len, sizeof(len));
- out.write(p.first.c_str(), len);
- out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
- int nval = p.second.values.size();
+ out.write(name.c_str(), len);
+ out.write((const char *) &stat.ncall, sizeof(stat.ncall));
+ int nval = stat.values.size();
out.write((const char *) &nval, sizeof(nval));
if (nval > 0) {
std::vector<float> tmp(nval);
for (int i = 0; i < nval; i++) {
- tmp[i] = (p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall);
+ tmp[i] = (stat.values[i] / static_cast<float>(stat.counts[i])) * static_cast<float>(stat.ncall);
}
out.write((const char*)tmp.data(), nval*sizeof(float));
}