summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-08-09 10:32:07 +0300
committerKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-08-09 16:00:31 +0200
commitf0d7a0d53b0ecdd43ba85bcd49309b291372ca67 (patch)
tree9af6feced146997d2c93daa200c75f9a890dd016 /ggml/src/ggml-cuda
parentc77dba5273777c6c43d9745fc96114eba867f6c2 (diff)
Fix Zen4 implementation of iq3_k, iq4_k, iq5_k
See comments in f3a823ce729a7db33e7d4375eae7291bbe6196db
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--ggml/src/ggml-cuda/convert.cu22
1 files changed, 8 insertions, 14 deletions
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index db5fd2dd..f76c80dc 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -591,12 +591,6 @@ static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_
}
}
-#define A_IQ6K -127.f
-#define B_IQ6K 6.2568f
-#define C_IQ6K 0.11218f
-#define D_IQ6K 0.0011972f
-#define S_IQ6K 1
-
template<typename dst_t>
static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {
@@ -617,14 +611,14 @@ static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_
const uint8_t extra = x[i].extra >> 4*(ib64%4);
for (int j = 0; j < 2; ++j) {
const uint8_t h1 = qh[j] >> 4*(ib64%2), h2 = qh[j+16] >> 4*(ib64%2);
- float q1 = (qs[j+ 0] & 0xf) | ((h1 & 0x03) << 4);
- float q2 = (qs[j+16] & 0xf) | ((h2 & 0x03) << 4);
- float q3 = (qs[j+ 0] >> 4) | ((h1 & 0x0c) << 2);
- float q4 = (qs[j+16] >> 4) | ((h2 & 0x0c) << 2);
- y[j+ 0] = dl1 * (A_IQ6K + q1*(B_IQ6K + q1*(-C_IQ6K + q1*D_IQ6K)) + (extra & 1 ? S_IQ6K : 0));
- y[j+16] = dl2 * (A_IQ6K + q2*(B_IQ6K + q2*(-C_IQ6K + q2*D_IQ6K)) + (extra & 2 ? S_IQ6K : 0));
- y[j+32] = dl3 * (A_IQ6K + q3*(B_IQ6K + q3*(-C_IQ6K + q3*D_IQ6K)) + (extra & 4 ? S_IQ6K : 0));
- y[j+48] = dl4 * (A_IQ6K + q4*(B_IQ6K + q4*(-C_IQ6K + q4*D_IQ6K)) + (extra & 8 ? S_IQ6K : 0));
+ uint8_t q1 = (qs[j+ 0] & 0xf) | ((h1 & 0x03) << 4);
+ uint8_t q2 = (qs[j+16] & 0xf) | ((h2 & 0x03) << 4);
+ uint8_t q3 = (qs[j+ 0] >> 4) | ((h1 & 0x0c) << 2);
+ uint8_t q4 = (qs[j+16] >> 4) | ((h2 & 0x0c) << 2);
+ y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0));
+ y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0));
+ y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0));
+ y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0));
}
}