summaryrefslogtreecommitdiff
path: root/ggml-cuda/convert.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda/convert.cu')
-rw-r--r--ggml-cuda/convert.cu26
1 files changed, 21 insertions, 5 deletions
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index a0a826e9..4a67c498 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -425,7 +425,10 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
const int64_t ii = blockIdx.x;
const block_iq1_bn * x = (const block_iq1_bn *) vx;
- static const uint16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+ static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+//#define COMPUTE_VS(v) 3*v >> 8
+#define COMPUTE_VS(v) (v + (v >> 1)) >> 7
const int tid = threadIdx.x;
const int il = tid/4; // 0...7
@@ -433,11 +436,24 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst
dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
int64_t i = QK_K/QK_IQ1BN * ii + ib;
if (i >= nb64) return;
- uint16_t val = x[i].ql[il] | ((x[i].qh[il%4] << (8 - 4*(il/4))) & 0x0f00) | ((x[i].extra << (12 - il)) & 4096);
- for (int j = 0; j < 8; ++j) {
- uint16_t v = (val*k_mult[j] & 0x1fff)*3 >> 13;
- y[j] = v - 1;
+ const int i16 = il/2;
+ uint8_t q = x[i].ql[3*i16+2*(il%2)];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[2*(il%2)+j] = vs - 1;
}
+ q = x[i].ql[3*i16+1];
+ for (int j = 0; j < 2; ++j) {
+ uint8_t v = k_mult[3*(il%2)+j]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[5*(1-(il%2))+j] = vs-1;
+ }
+ uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[7] = vs - 1;
+
+#undef COMPUTE_VS
}
template<typename dst_t>