diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-08-09 10:32:07 +0300 |
---|---|---|
committer | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-08-09 16:00:31 +0200 |
commit | f0d7a0d53b0ecdd43ba85bcd49309b291372ca67 (patch) | |
tree | 9af6feced146997d2c93daa200c75f9a890dd016 /ggml/src/ggml-cuda | |
parent | c77dba5273777c6c43d9745fc96114eba867f6c2 (diff) |
Fix Zen4 implementation of iq3_k, iq4_k, iq5_k
See comments in f3a823ce729a7db33e7d4375eae7291bbe6196db
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 22 |
1 files changed, 8 insertions, 14 deletions
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index db5fd2dd..f76c80dc 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -591,12 +591,6 @@ static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_ } } -#define A_IQ6K -127.f -#define B_IQ6K 6.2568f -#define C_IQ6K 0.11218f -#define D_IQ6K 0.0011972f -#define S_IQ6K 1 - template<typename dst_t> static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -617,14 +611,14 @@ static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_ const uint8_t extra = x[i].extra >> 4*(ib64%4); for (int j = 0; j < 2; ++j) { const uint8_t h1 = qh[j] >> 4*(ib64%2), h2 = qh[j+16] >> 4*(ib64%2); - float q1 = (qs[j+ 0] & 0xf) | ((h1 & 0x03) << 4); - float q2 = (qs[j+16] & 0xf) | ((h2 & 0x03) << 4); - float q3 = (qs[j+ 0] >> 4) | ((h1 & 0x0c) << 2); - float q4 = (qs[j+16] >> 4) | ((h2 & 0x0c) << 2); - y[j+ 0] = dl1 * (A_IQ6K + q1*(B_IQ6K + q1*(-C_IQ6K + q1*D_IQ6K)) + (extra & 1 ? S_IQ6K : 0)); - y[j+16] = dl2 * (A_IQ6K + q2*(B_IQ6K + q2*(-C_IQ6K + q2*D_IQ6K)) + (extra & 2 ? S_IQ6K : 0)); - y[j+32] = dl3 * (A_IQ6K + q3*(B_IQ6K + q3*(-C_IQ6K + q3*D_IQ6K)) + (extra & 4 ? S_IQ6K : 0)); - y[j+48] = dl4 * (A_IQ6K + q4*(B_IQ6K + q4*(-C_IQ6K + q4*D_IQ6K)) + (extra & 8 ? S_IQ6K : 0)); + uint8_t q1 = (qs[j+ 0] & 0xf) | ((h1 & 0x03) << 4); + uint8_t q2 = (qs[j+16] & 0xf) | ((h2 & 0x03) << 4); + uint8_t q3 = (qs[j+ 0] >> 4) | ((h1 & 0x0c) << 2); + uint8_t q4 = (qs[j+16] >> 4) | ((h2 & 0x0c) << 2); + y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0)); + y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0)); + y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0)); + y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0)); } } |