Diffstat (limited to 'ggml/src/iqk/iqk_gemm_ktquants.cpp')
-rw-r--r-- | ggml/src/iqk/iqk_gemm_ktquants.cpp | 176 |
1 file changed, 83 insertions, 93 deletions
diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp
index 0529128c..6604480d 100644
--- a/ggml/src/iqk/iqk_gemm_ktquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -89,7 +89,8 @@ struct Trellis2 {
     const __m256i mask2 = _mm256_set1_epi32(km32);
 
     inline __m256i next8(uint32_t val1, uint32_t val2) {
-        __m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
+        __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1));
+        //__m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
         __m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
         return _mm256_xor_si256(_mm256_and_si256(mres, _mm256_set1_epi32(kmask)), _mm256_set1_epi32(km32));
     }
@@ -163,28 +164,6 @@ static inline __m256 abs_ps(__m256 vals) {
     return _mm256_andnot_ps(sign_bit, vals);
 }
 
-// Negates 32-bit float lanes of an 8x32-bit vector
-// based on 8x8-bit condition var. For float lane i, if byte i of
-// `condition` is nonzero, the float will be negated.
-static inline __m256 conditional_negate_ps(__m256 vals, uint64_t condition_mask_u64) {
-    __m128i condition_bytes = _mm_set_epi64x(0, condition_mask_u64);
-    // Make `should_negate_byte_mask` where byte i == 0xFF if byte i in condition_bytes is zero,
-    // else 0x00 (upper bytes are meaningless)
-    __m128i zeros = _mm_setzero_si128();
-    __m128i is_zero_byte_mask = _mm_cmpeq_epi8(condition_bytes, zeros);
-    __m128i should_negate_byte_mask = _mm_cmpeq_epi8(is_zero_byte_mask, zeros);
-    // Widen lower 8x8 bits of `should_negate_byte_mask` to 8x32 bits by padding zeros
-    // expanded_mask_epi32[j] will be 0x000000FF if vals[j] should be negated, zero otherwise
-    __m256i expanded_mask_epi32 = _mm256_cvtepu8_epi32(should_negate_byte_mask);
-    // Same as above but with all 32 bits of lane j set if vals[j] should be negated (use to make XOR mask)
-    __m256i full_dword_negate_mask = _mm256_cmpgt_epi32(expanded_mask_epi32, _mm256_setzero_si256());
-    // Negate via XOR on sign bits of each 32-bit float
-    __m256i sign_bit_pattern = _mm256_set1_epi32(0x80000000); // MSB set for a 32-bit value
-    __m256i xor_mask_epi32 = _mm256_and_si256(full_dword_negate_mask, sign_bit_pattern);
-    __m256 xor_mask_ps = _mm256_castsi256_ps(xor_mask_epi32);
-    return _mm256_xor_ps(vals, xor_mask_ps);
-}
-
 template <int nrc_y>
 static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n%QK_K == 0);
@@ -192,6 +171,14 @@ static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
 
     Trellis1 trellis;
 
+    union { __m256 vec; float val[8]; } s_helper;
+
+    auto shifts = _mm_set_epi32(0, 0, 4, 0);
+
+    __m256i all_signs[4];
+    auto mask1 = _mm256_set1_epi32(0x01);
+    auto mask2 = _mm256_set1_epi32(0x10);
+
     __m256 accd[nrc_y];
     const float * y[nrc_y];
     for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
@@ -206,31 +193,28 @@ static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
         for (int i = 0; i < nb; ++i) {
             const uint16_t * ql = (const uint16_t *)x[i].ql;
             const uint8_t * qh = x[i].qh;
-            for (int j = 0; j < 128; j+=8) {
-                uint64_t mask1 = 0x0101010101010101 << (j/32);
-                uint64_t mask2 = mask1 << 4;
-                uint32_t val1 = ql[j/8] + 4096;
-                uint32_t val2 = ql[j/8+16] + 4096;
-                const uint64_t signs = *((const uint64_t *)(qh + (j%32)));
-                const float x_scale1 = (x[i].scales[j/32] & 0xf);
-                const float x_scale2 = (x[i].scales[j/32] >> 4);
-                const __m256 x_val1 = abs_ps(trellis_gen8(trellis.next8(val1)));
-                const __m256 x_val2 = abs_ps(trellis_gen8(trellis.next8(val2)));
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    accd[iy] = _mm256_fmadd_ps(
-                        conditional_negate_ps(
-                            _mm256_load_ps(y[iy] + i*QK_K+j), signs & mask1
-                        ),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
-                        accd[iy]
-                    );
-                    accd[iy] = _mm256_fmadd_ps(
-                        conditional_negate_ps(
-                            _mm256_load_ps(y[iy] + i*QK_K+j+128), signs & mask2
-                        ),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
-                        accd[iy]
-                    );
+            auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
+            s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
+            auto s32 = _mm256_cvtepi8_epi32(s8);
+            s_helper.vec = _mm256_cvtepi32_ps(s32);
+            for (int j = 0; j < 4; ++j) all_signs[j] = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*j)));
+            for (int ib = 0; ib < 4; ++ib) {
+                auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
+                auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
+                for (int j = 0; j < 4; ++j) {
+                    uint32_t val1 = ql[4*ib+j ] + 4096;
+                    uint32_t val2 = ql[4*ib+j+16] + 4096;
+                    auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask1), mask1), _mm256_set1_epi32(0x80000000));
+                    auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask2), mask2), _mm256_set1_epi32(0x80000000));
+                    all_signs[j] = _mm256_srli_epi32(all_signs[j], 1);
+                    auto x_val1 = abs_ps(trellis_gen8(trellis.next8(val1)));
+                    auto x_val2 = abs_ps(trellis_gen8(trellis.next8(val2)));
+                    x_val1 = _mm256_mul_ps(scale1, _mm256_xor_ps(x_val1, _mm256_castsi256_ps(sign1)));
+                    x_val2 = _mm256_mul_ps(scale2, _mm256_xor_ps(x_val2, _mm256_castsi256_ps(sign2)));
+                    for (int iy = 0; iy < nrc_y; ++iy) {
+                        accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j ), x_val1, accd[iy]);
+                        accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128), x_val2, accd[iy]);
+                    }
                 }
             }
         }
@@ -250,66 +234,72 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
 
     Trellis2 trellis;
 
-    __m256 accd[nrc_y];
-    __m256 accd2[nrc_y];
+    union { __m256 vec; float val[8]; } s_helper;
+    union { __m256i vec; uint32_t val[8]; } o_helper;
+
+    constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
+
+    __m256 accd[k_acc];
     const float * y[nrc_y];
-    for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
+    float row_sum[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        y[iy] = (const float *)info.src1_row(iy);
+        auto sum = _mm256_setzero_ps();
+        for (int i = 0; i < n/8; ++i) sum = _mm256_add_ps(sum, _mm256_loadu_ps(y[iy] + 8*i));
+        row_sum[iy] = hsum_float_8(sum);
+    }
 
     for (int ix = 0; ix < nrc_x; ++ix) {
         const float * dptr = (const float *)((const char*)vx + ix*bx);
-        const float d = dptr[0] * 31.75f * 1.01f;
-        const float row_av = dptr[1];
+        auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
+        auto dav = dptr[1];
         const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            accd[iy] = _mm256_setzero_ps();
-            accd2[iy] = _mm256_setzero_ps();
-        }
+        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();
 
         for (int i = 0; i < nb; ++i) {
+            auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
             const uint32_t * shb = x[i].qs;
             const uint8_t * ql = (const uint8_t *)(shb + 8);
             const uint8_t * qh = ql + kNumGroups;
-            for (int j = 0; j < 128; j+=8) {
-                const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15);
-                const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15);
-                const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64;
-                const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64;
-                const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
-                const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
-                uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
-                uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
-                uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
-                uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
-                const __m256 x_val1 = trellis_gen8(trellis.next8(val1, val3));
-                const __m256 x_val2 = trellis_gen8(trellis.next8(val2, val4));
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    accd[iy] = _mm256_fmadd_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
-                        accd[iy]
-                    );
-                    accd[iy] = _mm256_fmadd_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j+128),
-                        _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
-                        accd[iy]
-                    );
-                    accd2[iy] = _mm256_add_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j),
-                        accd2[iy]
-                    );
-                    accd2[iy] = _mm256_add_ps(
-                        _mm256_load_ps(y[iy] + i*QK_K+j+128),
-                        accd2[iy]
-                    );
+            auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
+            s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
+            o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
+            for (int ib = 0; ib < 4; ++ib) {
+                auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
+                auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
+                for (int j = 0; j < 4; ++j) {
+                    const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
+                    const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
+                    uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
+                    uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
+                    uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
+                    uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
+                    auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3)));
+                    auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4)));
+                    if constexpr (nrc_y == 1) {
+                        auto y1 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+ 0);
+                        auto y2 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+128);
+                        accd[0] = _mm256_fmadd_ps(y1, x_val1, accd[0]);
+                        accd[1] = _mm256_fmadd_ps(y2, x_val2, accd[1]);
+                    } else {
+                        for (int iy = 0; iy < nrc_y; ++iy) {
+                            auto y1 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+ 0);
+                            auto y2 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128);
+                            accd[iy] = _mm256_fmadd_ps(y1, x_val1, accd[iy]);
+                            accd[iy] = _mm256_fmadd_ps(y2, x_val2, accd[iy]);
+                        }
+                    }
                 }
             }
         }
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
-            __m256 res2 = _mm256_mul_ps(_mm256_set1_ps(row_av), accd2[iy]);
-            info.store(ix, iy, hsum_float_8(res) + hsum_float_8(res2));
+        if constexpr (nrc_y == 1) {
+            info.store(ix, 0, hsum_float_8(_mm256_add_ps(accd[0], accd[1])) + dav*row_sum[0]);
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix, iy, hsum_float_8(accd[iy]) + dav*row_sum[iy]);
+            }
         }
     }
 }
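Note on the iq3_kt change above: the removed conditional_negate_ps() byte-mask helper is replaced by sign handling done directly in the kernel. The qh sign bytes are widened to 32-bit lanes once (all_signs[]), and each step ANDs out one bit per lane (mask1 = 0x01 for the first half of the block, mask2 = 0x10 for the +128 half), compares it against the mask to get an all-ones lane, keeps only the float sign bit 0x80000000, and XORs that onto the trellis values; all_signs[] is then shifted right by one bit for the next group. The standalone sketch below, which is not part of the patch and uses illustrative names (sign_apply, qh), shows the same trick in isolation; it assumes AVX2 (e.g. g++ -O2 -mavx2).

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

// Flip the sign of vals[k] wherever bit `bit` of the k-th widened sign byte is set.
static inline __m256 sign_apply(__m256 vals, __m256i signs, __m256i bit) {
    __m256i hit  = _mm256_cmpeq_epi32(_mm256_and_si256(signs, bit), bit); // all-ones lane if the bit is set
    __m256i flip = _mm256_and_si256(hit, _mm256_set1_epi32(0x80000000));  // keep only the float sign bit
    return _mm256_xor_ps(vals, _mm256_castsi256_ps(flip));                // XOR toggles the sign
}

int main() {
    uint8_t qh[8] = {0x01, 0x00, 0x11, 0x10, 0x00, 0x01, 0x10, 0x11};     // example sign bytes
    // Widen the 8 sign bytes to 8x32-bit lanes once (the kernel does this into all_signs[]).
    __m256i signs = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qh));
    __m256  vals  = _mm256_set1_ps(1.0f);
    __m256  lo    = sign_apply(vals, signs, _mm256_set1_epi32(0x01));     // bit 0: signs for the first half
    __m256  hi    = sign_apply(vals, signs, _mm256_set1_epi32(0x10));     // bit 4: signs for the +128 half
    float a[8], b[8];
    _mm256_storeu_ps(a, lo);
    _mm256_storeu_ps(b, hi);
    for (int k = 0; k < 8; ++k) printf("%+.0f %+.0f\n", a[k], b[k]);      // -1 where the tested bit is set
    return 0;
}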