Diffstat (limited to 'ggml-quants.c')
-rw-r--r--  ggml-quants.c  26
1 file changed, 23 insertions, 3 deletions
diff --git a/ggml-quants.c b/ggml-quants.c
index 0971d696..061edddc 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -1318,7 +1318,15 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
wasm_i32x4_extract_lane(accv, 3)));
}
#elif defined(__AVX2__) || defined(__AVX__)
+ block_q8_1_x4 * restrict y4 = vy;
+ int nb4 = 4*(nb/4);
+#ifdef __AVX2__
+ const bool pack = true;
+#else
+ const bool pack = false;
+#endif
for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
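The packed block type itself is not touched by this diff, so its layout has to be inferred from the new accesses. A minimal sketch of what block_q8_1_x4 presumably looks like, consistent with the d[ir], d[ir+4] and qs + ir stores below and assuming QK8_1 == 32 (this definition is an assumption, not part of the commit):

typedef struct {
    ggml_half d[8];         // assumed: d[0..3] = scales of the 4 packed blocks, d[4..7] = their scaled quant sums
    int8_t    qs[4*QK8_1];  // 4 x 32 int8 quants, one contiguous 32-byte slot per packed block
} block_q8_1_x4;

Block i of the row then maps into the packed array as i4 = i/4 (which x4 block) and ir = i%4 (slot inside it); only the first nb4 = 4*(nb/4) blocks are packed, and any tail blocks keep the plain block_q8_1 layout in y[i].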
@@ -1340,7 +1348,11 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
// Quantize these floats
const float d = max_scalar / 127.f;
- y[i].d = GGML_FP32_TO_FP16(d);
+ if (pack && i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
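The per-block arithmetic being vectorized here and in the following hunks is the standard q8_1 scheme; a self-contained scalar sketch for one 32-value block, where q8_1_block_ref is a hypothetical helper name used only for illustration:

#include <math.h>
#include <stdint.h>

// Scalar reference for one 32-float block: compute the fp32 scale d and the
// scaled quant sum s that the SIMD code stores as fp16 into y[i].d / y[i].s,
// or into y4[i4].d[ir] / y4[i4].d[ir+4] on the packed path.
static void q8_1_block_ref(const float *x, int8_t *qs, float *d, float *s) {
    float amax = 0.0f;                           // max |x[j]| over the block
    for (int j = 0; j < 32; ++j) amax = fmaxf(amax, fabsf(x[j]));
    const float id = amax != 0.0f ? 127.f/amax : 0.0f;
    int sum = 0;
    for (int j = 0; j < 32; ++j) {
        const int q = (int)nearbyintf(x[j]*id);  // round to nearest, as _mm256_round_ps does
        qs[j] = (int8_t)q;
        sum  += q;
    }
    *d = amax/127.f;
    *s = *d * (float)sum;                        // q8_1 stores d * sum(quants) alongside the scale
}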
@@ -1364,7 +1376,11 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
#if defined(__AVX2__)
// Compute the sum of the quants and set y[i].s
- y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
+ if (i < nb4) {
+ y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
+ } else {
+ y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
+ }
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
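Nothing in this hunk shows how the packed values are read back; under the layout assumed above, a consumer iterating the packed row would presumably unpack a sub-block like this (sketch only, the helper name is made up):

// Fetch scale, scaled sum and quant pointer for sub-block ir (0..3) of one
// block_q8_1_x4, assuming d[8] holds 4 scales followed by 4 scaled sums.
static inline void q8_1_x4_get(const block_q8_1_x4 *b, int ir,
                               float *d, float *s, const int8_t **qs) {
    *d  = GGML_FP16_TO_FP32(b->d[ir]);      // scale of sub-block ir
    *s  = GGML_FP16_TO_FP32(b->d[ir + 4]);  // d * sum(quants) of that sub-block
    *qs = b->qs + 32*ir;                    // its 32 contiguous int8 quants
}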
@@ -1378,7 +1394,11 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+ if (i < nb4) {
+ _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+ }
#else
// Since we don't have in AVX some necessary functions,
// we split the registers in half and call AVX2 analogs from SSE