Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 47 |
1 file changed, 47 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 47ab92f0..db5fd2dd 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -591,6 +591,43 @@ static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_
     }
 }
 
+#define A_IQ6K -127.f
+#define B_IQ6K 6.2568f
+#define C_IQ6K 0.11218f
+#define D_IQ6K 0.0011972f
+#define S_IQ6K 1
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_iq6_k * x = (const block_iq6_k *) vx;
+
+    const int tid = threadIdx.x;
+    int ib64 = tid/8; // 0...3
+    int il   = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 64*ib64 + 2*il;
+    const float d = (float)x[i].d;
+    const float dl1 = d * x[i].scales[4*ib64+0];
+    const float dl2 = d * x[i].scales[4*ib64+1];
+    const float dl3 = d * x[i].scales[4*ib64+2];
+    const float dl4 = d * x[i].scales[4*ib64+3];
+    const uint8_t * qs = x[i].qs + 32*ib64 + 2*il;
+    const uint8_t * qh = x[i].qh + 32*(ib64/2) + 2*il;
+    const uint8_t extra = x[i].extra >> 4*(ib64%4);
+    for (int j = 0; j < 2; ++j) {
+        const uint8_t h1 = qh[j] >> 4*(ib64%2), h2 = qh[j+16] >> 4*(ib64%2);
+        float q1 = (qs[j+ 0] & 0xf) | ((h1 & 0x03) << 4);
+        float q2 = (qs[j+16] & 0xf) | ((h2 & 0x03) << 4);
+        float q3 = (qs[j+ 0] >>  4) | ((h1 & 0x0c) << 2);
+        float q4 = (qs[j+16] >>  4) | ((h2 & 0x0c) << 2);
+        y[j+ 0] = dl1 * (A_IQ6K + q1*(B_IQ6K + q1*(-C_IQ6K + q1*D_IQ6K)) + (extra & 1 ? S_IQ6K : 0));
+        y[j+16] = dl2 * (A_IQ6K + q2*(B_IQ6K + q2*(-C_IQ6K + q2*D_IQ6K)) + (extra & 2 ? S_IQ6K : 0));
+        y[j+32] = dl3 * (A_IQ6K + q3*(B_IQ6K + q3*(-C_IQ6K + q3*D_IQ6K)) + (extra & 4 ? S_IQ6K : 0));
+        y[j+48] = dl4 * (A_IQ6K + q4*(B_IQ6K + q4*(-C_IQ6K + q4*D_IQ6K)) + (extra & 8 ? S_IQ6K : 0));
+    }
+}
+
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
@@ -803,6 +840,12 @@ static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_iq5_k<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq6_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_iq6_k<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
     const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
@@ -877,6 +920,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq4_k_cuda;
         case GGML_TYPE_IQ5_K:
            return dequantize_row_iq5_k_cuda;
+        case GGML_TYPE_IQ6_K:
+            return dequantize_row_iq6_k_cuda;
        case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_cuda;
        case GGML_TYPE_F32:
@@ -938,6 +983,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
            return dequantize_row_iq4_k_cuda;
        case GGML_TYPE_IQ5_K:
            return dequantize_row_iq5_k_cuda;
+        case GGML_TYPE_IQ6_K:
+            return dequantize_row_iq6_k_cuda;
        case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_cuda;
        case GGML_TYPE_F16:
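
Reviewer note: in the kernel above, each 6-bit index q = 0..63 is mapped to a non-uniform grid value by the cubic A_IQ6K + q*(B_IQ6K + q*(-C_IQ6K + q*D_IQ6K)); a set "extra" bit adds S_IQ6K to its 16-value group, and the result is then multiplied by that group's scale dl. A minimal standalone host-side sketch (not part of the patch, purely illustrative) that tabulates this grid:

    // Host-side sketch: prints the 64-entry IQ6_K grid implied by the constants
    // in the patch, before the per-group scale dl is applied.
    #include <cstdio>

    int main() {
        const float A = -127.f, B = 6.2568f, C = 0.11218f, D = 0.0011972f; // A_IQ6K..D_IQ6K
        const float S = 1.0f;                                              // S_IQ6K
        for (int q = 0; q < 64; ++q) {
            const float v = A + q*(B + q*(-C + q*D)); // cubic mapping used by the kernel
            printf("q=%2d -> %9.4f (with extra bit: %9.4f)\n", q, v, v + S);
        }
        return 0;
    }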