summary | refs | log | tree | commit | diff
path: root/utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'utils.cpp')
-rw-r--r-- utils.cpp | 20
1 files changed, 12 insertions, 8 deletions
diff --git a/utils.cpp b/utils.cpp
index aa3ad105..26e313d5 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -489,7 +489,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t
size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t * hist) {
const int nb = k / qk;
- const size_t row_size = nb*(2*sizeof(float) + sizeof(uint8_t)*qk/2);
+ const size_t bs = (2*sizeof(float) + sizeof(uint8_t)*qk/2);
+ const size_t row_size = nb*bs;
assert(k % qk == 0);
@@ -498,10 +499,10 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
char * pdst = (char *) dst;
- for (int j = 0; j < n; j += k) {
- float * pm = (float *) (pdst + (j/k)*row_size);
- float * pd = (float *) (pm + nb);
- uint8_t * pb = (uint8_t *) (pd + nb);
+ for (int j = 0; j < n; j += k) {
+ uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
+ uint8_t * pm = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
+ uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
//printf("n = %d, k = %d, nb = %d, row_size = %d, j = %d, pm = %p, pd = %p, pb = %p\n", n, k, nb, row_size, j, pm, pd, pb);
@@ -519,8 +520,10 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
- pm[i] = min;
- pd[i] = d;
+ *(float *) pd = d;
+ *(float *) pm = min;
+ pd += bs;
+ pm += bs;
for (int l = 0; l < qk; l += 2) {
const float v0 = (src[j + i*qk + l + 0] - min)*id;
@@ -538,7 +541,8 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
pp[l/2] = vi0 | (vi1 << 4);
}
- memcpy(pb + i*qk/2, pp, pp_size);
+ memcpy(pb, pp, pp_size);
+ pb += bs;
}
}
}