diff options
Diffstat (limited to 'ggml.c')
-rw-r--r-- | ggml.c | 97 |
1 files changed, 55 insertions, 42 deletions
@@ -366,9 +366,10 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { assert(k % QK == 0); const int nb = k / QK; + const size_t bs = sizeof(float) + QK/2; - float * restrict pd = (float *) (y); - uint8_t * restrict pb = (uint8_t *) (pd + nb); + uint8_t * restrict pd = (uint8_t *) (y + 0*bs); + uint8_t * restrict pb = (uint8_t *) (y + 0*bs + sizeof(float)); uint8_t pp[QK/2]; @@ -395,7 +396,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { const float d = amax / ((1 << 3) - 1); const float id = d ? 1.0/d : 0.0; - pd[i] = d; + *(float *)pd = d; + pd += bs; for (int l = 0; l < 8; l++) { const float32x4_t v = vmulq_n_f32(srcv[l], id); @@ -406,7 +408,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); } - memcpy(pb + i*16, pp, sizeof(pp)); + memcpy(pb, pp, sizeof(pp)); + pb += bs; } #else #error "not implemented for QK" @@ -434,7 +437,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { const float d = amax / ((1 << 3) - 1); const float id = d ? 1.0/d : 0.0; - pd[i] = d; + *(float *)pd = d; + pd += bs; for (int l = 0; l < 8; l++) { const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id)); @@ -445,7 +449,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4); } - memcpy(pb + i*16, pp, sizeof(pp)); + memcpy(pb, pp, sizeof(pp)); + pb += bs; } #else #error "not implemented for QK" @@ -463,7 +468,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { const float d = amax / ((1 << 3) - 1); const float id = d ? 1.0f/d : 0.0f; - pd[i] = d; + *(float *)pd = d; + pd += bs; for (int l = 0; l < QK; l += 2) { const float v0 = x[i*QK + l + 0]*id; @@ -478,7 +484,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { pp[l/2] = vi0 | (vi1 << 4); } - memcpy(pb + i*QK/2, pp, sizeof(pp)); + memcpy(pb, pp, sizeof(pp)); + pb += bs; } #endif } @@ -535,15 +542,16 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { assert(k % QK == 0); const int nb = k / QK; + const size_t bs = sizeof(float) + QK/2; - const float * restrict pd = (const float *) (x); - const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + const uint8_t * restrict pd = (const uint8_t *) (x + 0*bs); + const uint8_t * restrict pb = (const uint8_t *) (x + 0*bs + sizeof(float)); // scalar for (int i = 0; i < nb; i++) { - const float d = pd[i]; + const float d = *(const float *) (pd + i*bs); - const uint8_t * restrict pp = pb + i*QK/2; + const uint8_t * restrict pp = pb + i*bs; for (int l = 0; l < QK; l += 2) { const uint8_t vi = pp[l/2]; @@ -554,6 +562,8 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { const float v0 = (vi0 - 8)*d; const float v1 = (vi1 - 8)*d; + //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1); + y[i*QK + l + 0] = v0; y[i*QK + l + 1] = v1; @@ -1179,11 +1189,13 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void assert(n % QK == 0); assert(nb % 2 == 0); - const float * restrict pd0 = (const float *) x; - const float * restrict pd1 = (const float *) y; + const size_t bs = sizeof(float) + QK/2; - const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb); - const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb); + const uint8_t * restrict pd0 = (const uint8_t *) (x + 0*bs); + const uint8_t * restrict pd1 = (const uint8_t *) (y + 0*bs); + + const uint8_t * restrict pb0 = (const uint8_t *) (x + 0*bs + sizeof(float)); + const uint8_t * restrict pb1 = (const uint8_t *) (y + 0*bs + sizeof(float)); float sumf = 0.0; @@ -1193,23 +1205,23 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void float sum1 = 0.0f; for (int i = 0; i < nb; i += 2) { - const float d0_0 = pd0[i + 0]; - const float d1_0 = pd1[i + 0]; - const float d0_1 = pd0[i + 1]; - const float d1_1 = pd1[i + 1]; + const float d0_0 = *(const float *) (pd0 + i*bs); + const float d1_0 = *(const float *) (pd1 + i*bs); + const float d0_1 = *(const float *) (pd0 + (i + 1)*bs); + const float d1_1 = *(const float *) (pd1 + (i + 1)*bs); //printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1); - const uint8_t * restrict p0 = pb0 + i*16; - const uint8_t * restrict p1 = pb1 + i*16; + const uint8_t * restrict p0 = pb0 + i*bs; + const uint8_t * restrict p1 = pb1 + i*bs; const uint8x16_t m4b = vdupq_n_u8(0xf); const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(p0); const uint8x16_t v1_0 = vld1q_u8(p1); - const uint8x16_t v0_1 = vld1q_u8(p0 + 16); - const uint8x16_t v1_1 = vld1q_u8(p1 + 16); + const uint8x16_t v0_1 = vld1q_u8(p0 + bs); + const uint8x16_t v1_1 = vld1q_u8(p1 + bs); // 4-bit -> 8-bit const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); @@ -1280,21 +1292,21 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void float sum1 = 0.0f; for (int i = 0; i < nb; i += 2) { - const float d0_0 = pd0[i + 0]; - const float d0_1 = pd0[i + 1]; - const float d1_0 = pd1[i + 0]; - const float d1_1 = pd1[i + 1]; + const float d0_0 = *(const float *) (pd0 + i*bs); + const float d1_0 = *(const float *) (pd1 + i*bs); + const float d0_1 = *(const float *) (pd0 + (i + 1)*bs); + const float d1_1 = *(const float *) (pd1 + (i + 1)*bs); - const uint8_t * restrict p0 = pb0 + i*16; - const uint8_t * restrict p1 = pb1 + i*16; + const uint8_t * restrict p0 = pb0 + i*bs; + const uint8_t * restrict p1 = pb1 + i*bs; const v128_t m4b = wasm_u8x16_splat(0xf); const v128_t s8b = wasm_i8x16_splat(0x8); const v128_t v0_0 = wasm_v128_load(p0); - const v128_t v0_1 = wasm_v128_load(p0 + 16); + const v128_t v0_1 = wasm_v128_load(p0 + bs); const v128_t v1_0 = wasm_v128_load(p1); - const v128_t v1_1 = wasm_v128_load(p1 + 16); + const v128_t v1_1 = wasm_v128_load(p1 + bs); // 4-bit -> 8-bit const v128_t v0_0l = wasm_v128_and(v0_0, m4b); @@ -1363,11 +1375,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void #else // scalar for (int i = 0; i < nb; i++) { - const float d0 = pd0[i]; - const float d1 = pd1[i]; + const float d0 = *(const float *) (pd0 + i*bs); + const float d1 = *(const float *) (pd1 + i*bs); - const uint8_t * restrict p0 = pb0 + i*QK/2; - const uint8_t * restrict p1 = pb1 + i*QK/2; + const uint8_t * restrict p0 = pb0 + i*bs; + const uint8_t * restrict p1 = pb1 + i*bs; for (int j = 0; j < QK/2; j++) { const uint8_t v0 = p0[j]; @@ -1552,16 +1564,17 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res assert(n % QK == 0); const int nb = n / QK; + const size_t bs = sizeof(float) + QK/2; - const float * restrict pd = (const float *) (x); - const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + const uint8_t * restrict pd = (const uint8_t *) (x + 0*bs); + const uint8_t * restrict pb = (const uint8_t *) (x + 0*bs + sizeof(float)); #if __ARM_NEON #if QK == 32 for (int i = 0; i < nb; ++i) { - const float d0 = pd[i]*v; + const float d0 = v*(*(const float *) (pd + i*bs)); - const uint8_t * restrict pp = pb + i*16; + const uint8_t * restrict pp = pb + i*bs; const uint8x8_t m4b = vdup_n_u8(0xf); const int8x8_t s8b = vdup_n_s8(0x8); @@ -1615,9 +1628,9 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res #else // scalar for (int i = 0; i < nb; i++) { - const float d = pd[i]; + const float d = *(const float *) (pd + i*bs); - const uint8_t * restrict pp = pb + i*QK/2; + const uint8_t * restrict pp = pb + i*bs; for (int l = 0; l < QK; l += 2) { const uint8_t vi = pp[l/2]; |