| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
 | // Unit tests for quantization specific functions - quantize, dequantize and dot product
#include "ggml.h"
#undef NDEBUG
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <string>
#include <vector>
#include <random>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// Upper bound on the difference between the optimized quantizer (from_float) and the
// reference quantizer (from_float_ref), as measured by reference_quantization_error().
constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
// Upper bounds on the quantize -> dequantize round-trip error measured by
// total_quantization_error(). main() picks the looser 2-bit / 3-bit bounds for the
// low-bit types and MAX_QUANTIZATION_TOTAL_ERROR for everything else.
constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
// Upper bounds on the per-element dot product error returned by dot_product_error()
// (absolute error divided by the element count). The LOWBIT bound is used by main()
// for the 2-bit and 3-bit types.
constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
// Indexed with a bool "failed" flag: RESULT_STR[false] is "ok", RESULT_STR[true] is "FAILED".
static const char* RESULT_STR[] = {"ok", "FAILED"};
// Fill dst[0..n) with a deterministic smooth signal: dst[k] = 0.1 + 2*cos(k + offset).
// Values lie roughly in [-1.9, 2.1]; different offsets give phase-shifted series.
static void generate_data(float offset, size_t n, float * dst) {
    size_t k = 0;
    while (k < n) {
        const float phase = k + offset;
        dst[k] = 0.1 + 2*cosf(phase);
        ++k;
    }
}
// Fill dst[0..n) with ternary values {-1, 0, +1} drawn from a std::mt19937 seeded with
// 1234, so every call produces the same sequence. A draw above max/2 yields 0 (about half
// the values), a draw below max/4 yields -1, and anything else yields +1.
static void generate_bitnet_data(size_t n, float * dst) {
    constexpr auto kMax = std::mt19937::max();
    std::mt19937 gen(1234);
    for (float * out = dst; out != dst + n; ++out) {
        const auto draw = gen();
        if (draw > kMax/2) {
            *out = 0.f;
        } else if (draw < kMax/4) {
            *out = -1.f;
        } else {
            *out = 1.f;
        }
    }
}
// Error metric used by all quantization checks: sqrt(sum_i (a1[i] - a2[i])^2) / n.
// NOTE: despite the name this is NOT the textbook RMSE sqrt(sum / n); it is smaller by a
// factor of sqrt(n). The MAX_QUANTIZATION_* thresholds in this file are compared against
// this exact value, so changing the formula would require rescaling them.
// Each difference is formed in float, squares are accumulated in double, and the final
// square root is taken in float.
static float array_rmse(const float * a1, const float * a2, size_t n) {
    const float * const a1_end = a1 + n;
    double sq_sum = 0;
    for (const float * p = a1, * q = a2; p != a1_end; ++p, ++q) {
        const double d = *p - *q;
        sq_sum += d * d;
    }
    return sqrtf(static_cast<float>(sq_sum)) / n;
}
// Total quantization error on test data.
// Quantizes test_data with the type's optimized quantizer (from_float), dequantizes it
// again (to_float), and returns array_rmse() between the input and the round-tripped values.
static float total_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
    // Size the scratch buffer from the type's block layout (type_size bytes per blck_size
    // elements) instead of assuming at most 2 bytes per element. A wider encoding would
    // otherwise overflow the heap buffer. 2*test_size, the previous fixed size, is kept
    // as a floor so no type gets less room than before.
    const size_t blck    = qfns.blck_size > 0 ? static_cast<size_t>(qfns.blck_size) : 1;
    const size_t q_bytes = (test_size + blck - 1) / blck * qfns.type_size;
    std::vector<uint8_t> tmp_q(q_bytes > 2*test_size ? q_bytes : 2*test_size);
    std::vector<float> tmp_out(test_size);

    qfns.from_float(test_data, tmp_q.data(), test_size);
    qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);

    return array_rmse(test_data, tmp_out.data(), test_size);
}
// Difference between the optimized and the reference quantizer on test data.
// Quantizes test_data once with from_float and once with from_float_ref, dequantizes both
// with to_float, and returns array_rmse() between the two dequantized arrays. A non-zero
// result means the optimized path does not reproduce the reference implementation.
static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
    // Size the scratch buffer from the type's block layout (type_size bytes per blck_size
    // elements) instead of assuming at most 2 bytes per element. A wider encoding would
    // otherwise overflow the heap buffer. 2*test_size, the previous fixed size, is kept
    // as a floor so no type gets less room than before.
    const size_t blck    = qfns.blck_size > 0 ? static_cast<size_t>(qfns.blck_size) : 1;
    const size_t q_bytes = (test_size + blck - 1) / blck * qfns.type_size;
    std::vector<uint8_t> tmp_q(q_bytes > 2*test_size ? q_bytes : 2*test_size);
    std::vector<float> tmp_out(test_size);
    std::vector<float> tmp_out_ref(test_size);

    // Optimized quantizer -> dequantize.
    qfns.from_float(test_data, tmp_q.data(), test_size);
    qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);

    // Reference quantizer -> dequantize; the scratch buffer is reused and fully rewritten.
    qfns.from_float_ref(test_data, tmp_q.data(), test_size);
    qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size);

    return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
}
// Exact (reference) dot product of two float arrays of length test_size.
// Each product is formed in float and accumulated in double, in index order; the
// total is narrowed to float on return.
static float dot_product(const float * a1, const float * a2, size_t test_size) {
    double total = 0;
    const float * const stop = a1 + test_size;
    for (const float * x = a1, * y = a2; x != stop; ++x, ++y) {
        const float prod = *x * *y;
        total += prod;
    }
    return static_cast<float>(total);
}
// Total dot product error.
// Quantizes test_data1 with the type's reference quantizer (from_float_ref) and test_data2
// with the quantizer of the type's vec_dot_type. It then runs the type's vec_dot kernel on
// the pair and returns |vec_dot result - exact dot product| / test_size, where the exact
// value is dot_product() over the original floats.
static float dot_product_error(
    ggml_type_traits_t & qfns, size_t test_size, const float * test_data1, const float *test_data2
) {
    // Traits of the format the second operand must be in for qfns.vec_dot.
    auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);

    // Bytes needed to hold test_size elements in a given format, derived from its block
    // layout (type_size bytes per blck_size elements). 2*test_size, the previous fixed size,
    // is kept as a floor so no format gets a smaller buffer than before.
    const auto scratch_bytes = [test_size](const ggml_type_traits_t & t) -> size_t {
        const size_t blck  = t.blck_size > 0 ? static_cast<size_t>(t.blck_size) : 1;
        const size_t bytes = (test_size + blck - 1) / blck * t.type_size;
        return bytes > 2*test_size ? bytes : 2*test_size;
    };
    std::vector<uint8_t> tmp_q1(scratch_bytes(qfns));
    std::vector<uint8_t> tmp_q2(scratch_bytes(vdot));

    qfns.from_float_ref(test_data1, tmp_q1.data(), test_size);
    vdot.from_float(test_data2, tmp_q2.data(), test_size);

    // Sentinel: if vec_dot never writes the result, the returned error is infinite and
    // cannot pass a finite threshold.
    float result = INFINITY;
    qfns.vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);

    const float dot_ref = dot_product(test_data1, test_data2, test_size);

    return fabsf(result - dot_ref) / test_size;
}
int main(int argc, char * argv[]) {
    bool verbose = false;
    const size_t test_size = 32 * 128;
    std::string arg;
    for (int i = 1; i < argc; i++) {
        arg = argv[i];
        if (arg == "-v") {
            verbose = true;
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            return 1;
        }
    }
    std::vector<float> test_data(test_size);
    std::vector<float> test_data2(test_size);
    std::vector<float> test_data_bitnet(test_size);
    generate_data(0.0, test_data.size(), test_data.data());
    generate_data(1.0, test_data2.size(), test_data2.data());
    generate_bitnet_data(test_data_bitnet.size(), test_data_bitnet.data());
    // Initialize GGML, ensures float conversion tables are initialized
    struct ggml_init_params ggml_params = {
        /* .mem_size   = */ 1*1024,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true,
    };
    struct ggml_context * ctx = ggml_init(ggml_params);
    int num_failed = 0;
    bool failed = false;
    for (int i = 0; i < GGML_TYPE_COUNT; i++) {
        ggml_type type = (ggml_type) i;
        ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
        // deprecated - skip
        if (qfns.blck_size == 0) {
            continue;
        }
        auto test_data_quantize = test_data.data();
        auto test_data_vecdot   = test_data2.data();
        const ggml_type ei = (ggml_type)i;
        if (ei == GGML_TYPE_IQ1_BN || ei == GGML_TYPE_IQ2_BN) {
            test_data_quantize = test_data_bitnet.data();
            test_data_vecdot   = test_data_bitnet.data();
            //printf("Skipping %s because test data does not satisfy Bitnet requirements\n", ggml_type_name(ei));
            //continue;
        }
        printf("Testing %s\n", ggml_type_name((ggml_type) i));
        ggml_quantize_init(ei);
        if (qfns.from_float && qfns.to_float) {
            const float total_error = total_quantization_error(qfns, test_size, test_data_quantize);
            const float max_quantization_error =
                type == GGML_TYPE_Q2_K    ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
                type == GGML_TYPE_IQ2_S   ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
                type == GGML_TYPE_Q3_K    ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
                type == GGML_TYPE_IQ3_S   ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
                type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR;
            failed = !(total_error < max_quantization_error);
            num_failed += failed;
            if (failed || verbose) {
                printf("%5s absolute quantization error:    %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
            }
            const float reference_error = reference_quantization_error(qfns, test_size, test_data_quantize);
            failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
            num_failed += failed;
            if (failed || verbose) {
                printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
            }
            const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data_vecdot);
            const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
                                            type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
                                          ? MAX_DOT_PRODUCT_ERROR_LOWBIT
                                          : MAX_DOT_PRODUCT_ERROR;
            failed = !(vec_dot_error < max_allowed_error);
            num_failed += failed;
            if (failed || verbose) {
                printf("%5s dot product error:              %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
            }
        }
    }
    if (num_failed || verbose) {
        printf("%d tests failed\n", num_failed);
    }
    ggml_free(ctx);
    return num_failed > 0;
}
 |