summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-08-19 13:49:28 +0300
committerGitHub <noreply@github.com>2024-08-19 13:49:28 +0300
commit5652100afcc423cf6342778cde372ca6aa54a79b (patch)
tree69c36e6554bef43977153b41c9ca2042cb47337e
parentc7b47fc67f23d1296b5b803337c27d8534373161 (diff)
quantize_stats: print rmse and max error as fraction of <x> (#21)
This allows for a better comparison between different models or different tensors of the same model where the magnitude of the model weights may differ. Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--examples/quantize-stats/quantize-stats.cpp7
1 files changed, 6 insertions, 1 deletions
diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp
index 837a17c7..6264deb4 100644
--- a/examples/quantize-stats/quantize-stats.cpp
+++ b/examples/quantize-stats/quantize-stats.cpp
@@ -40,6 +40,7 @@ struct error_stats {
size_t num_samples;
double total_error;
double max_error;
+ double sum_x2;
uint64_t error_histogram[HISTOGRAM_BUCKETS];
};
@@ -89,6 +90,7 @@ static void update_error_stats(int64_t nelements, const float * input, const flo
double diff = input[i] - output[i];
stats.total_error += diff * diff;
stats.max_error = fmax(fabs(diff), stats.max_error);
+ stats.sum_x2 += input[i]*input[i];
stats.error_histogram[std::max(std::min((size_t) floor(fabs(diff) / HISTOGRAM_RANGE * HISTOGRAM_BUCKETS), HISTOGRAM_BUCKETS-1), (size_t) 0)]++;
}
stats.num_samples += nelements;
@@ -97,6 +99,7 @@ static void update_error_stats(int64_t nelements, const float * input, const flo
static void combine_error_stats(error_stats & into, const error_stats & from) {
into.num_samples += from.num_samples;
into.total_error += from.total_error;
+ into.sum_x2 += from.sum_x2;
if (from.max_error > into.max_error) into.max_error = from.max_error;
for (size_t i=0; i<HISTOGRAM_BUCKETS; ++i) into.error_histogram[i] += from.error_histogram[i];
}
@@ -116,9 +119,11 @@ static double find_quantile(const error_stats & stats, double quantile) {
static void print_error_stats(const std::string & name, const error_stats & stats, bool print_histogram) {
double rmse = sqrt(stats.total_error / (double) stats.num_samples);
+ double av_x = sqrt(stats.sum_x2 / (double) stats.num_samples);
double median = find_quantile(stats, .5);
double pct95 = find_quantile(stats, .95);
- printf("%-50s: rmse %.8f, maxerr %.8f, 95pct<%.4f, median<%.4f\n", name.c_str(), rmse, stats.max_error, pct95, median);
+ printf("%-40s: rmse %.8f, %.6f maxerr %.8f, %.6f 95pct<%.4f, median<%.4f\n", name.c_str(), rmse, rmse/av_x,
+ stats.max_error, stats.max_error/av_x, pct95, median);
if (print_histogram) {
printf("Error distribution:\n");
for (size_t i = 0; i < HISTOGRAM_BUCKETS; i++) {