From a1c931c30ce9c5618ec56fe93234110343111710 Mon Sep 17 00:00:00 2001
From: Andrew Chan
Date: Thu, 22 May 2025 23:17:52 -0700
Subject: Trellis quants with CPU inference (#441)

* WIP

* WIP

* WIP

* Testing Trellis quantization

Using 12 bits per 8 weights I get a better rmse than iq2_xxs. I still need to see how quantizing the group-of-8 scales will affect accuracy. By AVX2 SIMDifying the search for the best code, LLaMA-3.1-8B gets quantized in 130 seconds on the Ryzen-7950X CPU - sluggish but still acceptable.

* Testing Trellis quantization: 4-bit quantized block scales

rmse increases by just 3%, so this is beating iq2_xxs in terms of rmse at the same 2.0625 bpw.

* Testing Trellis quantization: playing with scales and generators

* iq2_kt: quantize / dequantize

I now see that I was comparing apples to oranges: iq2_xxs was using a weight of sigma^2/4 + x^2, while the Trellis approach wasn't (weight = 1). Once I use the same weight, iq2_kt is actually slightly worse than iq2_xxs in terms of rmse, so it does not look promising at this point. Also, once each group of 8 Trellis values no longer has a constant sum(q^2) that we can precompute, quantization becomes significantly slower (476 seconds for LLaMA-3.1-8B).

* iq2_kt: CUDA dequantize so we can run perplexity calcs

As already indicated by rmse, the 2-bit trellis approach is quite a bit worse than iq2_xxs.

* WIP

* WIP

* WIP - try larger blocks

With blocks of 32 and 16 bits per group of 8 the brute-force search becomes prohibitive in terms of CPU time (30+ minutes for 8B LLaMA after SIMDifying with AVX2). The trick is to group the points into clusters, find the nearest cluster, and only search within the cluster (sketched below).

* iq2_kt - this is better

Using blocks of 32 and 16 bits per group of 8 weights it beats iq2_xxs in terms of PPL by a significant margin. It is 0.0625 bpw larger, but even if we go to 15 bits per group of 8 (so 0.0625 bpw less than iq2_xxs), PPL is still lower.

* iq2_kt - even better

Re-quantize after determining block scales (at the expense of much longer quantization time).

* iq2_kt: CUDA dot product

Implemented as DMMV. Very slow - just 81 t/s for LLaMA-3.1-8B. Then again, Q2_K_S when forced to use DMMV only gets 112 t/s vs 145 t/s via MMVQ. My memory is that when the DMMV kernels were properly maintained/used, DMMV was about on par with MMVQ for k-quants on my GPU.

* iq2_kt: very slightly faster CUDA dot product

* iq2_kt: f16 CUDA dot product

We arrive at 112 t/s.

* iq2_kt: faster f16 CUDA dot product

We arrive at 139 t/s (no FA), and 149 t/s (FA). My RTX-4080 is ~20% slower than the RTX-6000 quoted in the QTIP repository, so with FA (which I'm sure they also used) we are at around ~180 t/s on their GPU, so almost matching their performance.

* iq2_kt: faster f16 CUDA dot product

We arrive at 146 t/s (no FA), and 158 t/s (FA). This is measured for LLaMA-3.1-8B with output.weight left as f16.

* Minor

* Adding iq3_kt

3.125 bpw. So far it does not look good on the PPL vs bpw plot.

* Forgotten change

* WIP

* WIP

* iq3_kt WIP: slowly improving

PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is starting to be competitive/slightly better than other quants.

* WIP

* iq3_kt WIP: slowly improving

PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7892

* iq3_kt WIP: slowly improving

PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7689 after shrinking by 0.015 bpw by using iq4_k instead of q5_k for attn_v.
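A minimal, purely illustrative C++ sketch of the two-level nearest-point search described in the "try larger blocks" commit above (and sped up in the next commit by keeping each cluster's points in contiguous memory). This is not code from this patch; all names and types are hypothetical.

    #include <cstddef>
    #include <limits>
    #include <vector>

    // A codebook split into clusters; each cluster keeps its member points in one
    // contiguous, flattened buffer so the inner search streams through memory.
    struct Cluster {
        std::vector<float> centroid;   // d values
        std::vector<float> points;     // n_members * d values, contiguous
    };

    static float dist2(const float * a, const float * b, size_t d) {
        float s = 0.0f;
        for (size_t i = 0; i < d; ++i) { const float t = a[i] - b[i]; s += t*t; }
        return s;
    }

    // Two-level search: pick the nearest cluster by centroid distance, then
    // brute-force only inside that cluster instead of over the whole codebook.
    static const float * nearest_point(const std::vector<Cluster> & clusters, const float * q, size_t d) {
        if (clusters.empty()) return nullptr;
        size_t best_c = 0;
        float best = std::numeric_limits<float>::max();
        for (size_t c = 0; c < clusters.size(); ++c) {
            const float t = dist2(clusters[c].centroid.data(), q, d);
            if (t < best) { best = t; best_c = c; }
        }
        const std::vector<float> & pts = clusters[best_c].points;
        const float * best_p = nullptr;
        best = std::numeric_limits<float>::max();
        for (size_t i = 0; i + d <= pts.size(); i += d) {
            const float t = dist2(&pts[i], q, d);
            if (t < best) { best = t; best_p = &pts[i]; }
        }
        return best_p;   // nearest point within the nearest cluster (approximate global nearest)
    }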
* iq3_kt WIP: speed up quantization

Nearly 60% improvement in quantization speed by having the points belonging to a cluster copied to contiguous memory during initialization, and then accessed sequentially while searching for the closest point. LLaMA-3.1-8B now gets quantized in ~150 seconds on the Ryzen-5975WX.

* iq3_kt speed up quantization

Same trick as last commit applied to iq2_kt. Here we get an even larger speedup: quantization time on the Ryzen-5975WX for LLaMA-3.1-8B drops to 195 seconds from 375 seconds!

* iq3_kt: CUDA dot product

* iq2_kt: SOTA

We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.2406
PPL(LLaMA-2-7B, 4096) = 6.4179

* iq2_kt: SOTA

We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642
PPL(LLaMA-2-7B, 4096) = 6.3920

* Adding iq4_kt - not competitive at this point

* WIP

* WIP

* iq4_kt: CUDA dot product

* iq4_kt: minor tweaks

* iq2_kt: SOTA

We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642
PPL(LLaMA-2-7B, 4096) = 6.3920

* iq2_kt: SOTA

We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.0297
PPL(LLaMA-2-7B, 4096) = 6.3913

Ah, quantization is faster too. About 20% faster.

* iq3_kt: small improvements and faster quantization

* iq2_kt: SOTA

We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 8.9627
PPL(LLaMA-2-7B, 4096) = 6.3825

Quantization is faster too: ~200 seconds for LLaMA-3.1-8B on Ryzen-5975WX.

* iq3_kt: small progress

* WIP

* iq4_kt: go to 4.0 bpw

15 bits per group of 4, plus 8-bit scales for blocks of 32. This gives a slightly better PPL than iq4_kss.

* iq4_kt: very slightly better at the expense of much longer quantization time.

* iq4_kt: failed attempt to adjust CUDA dot product

It was working for 4.125 bpw. But after changing to 4.0 bpw there is something wrong and I don't see the bug.

* DRY

* DRY

* iq4_kt: CUDA dot product works

* DRY

* Report actual bpw

* Minor tweaks

* Checkpoint

Go to groups of 8 for iq3_kt. 2 x 8 = 16 bits for the magnitude plus 1 bpw for the sign. It gives a visible improvement in the PPL vs bpw plot, but that comes at the expense of much longer quantization time (7.5 minutes for LLaMA-3.1-8B on the Ryzen-5975WX). I also noticed that the 3INST generator is not actually generating a Gaussian distribution. But going to a better generator means readjusting all the hyper-parameters, so leaving it for later.
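A minimal host-side C++ sketch of the "3INST" trellis generator that the CUDA kernels below implement in trellis_next(): a 32-bit LCG step, a mask/xor that shapes the word into two fp16 values, and the sum of those two halves. The constants are copied from the patch; the hand-rolled fp16 decoder and the example index in main() are illustrative assumptions (the kernels operate on the half type directly).

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Minimal IEEE binary16 -> binary32 conversion so the sketch needs no CUDA/half type.
    static float fp16_to_float(uint16_t h) {
        const uint32_t sign = (uint32_t)(h & 0x8000) << 16;
        const uint32_t exp  = (h >> 10) & 0x1f;
        uint32_t mant = h & 0x3ff;
        uint32_t bits;
        if (exp == 0) {
            if (mant == 0) {
                bits = sign;                                   // +/- zero
            } else {                                           // subnormal: normalize
                int e = -1;
                do { mant <<= 1; ++e; } while (!(mant & 0x400));
                bits = sign | ((uint32_t)(127 - 15 - e) << 23) | ((mant & 0x3ff) << 13);
            }
        } else if (exp == 31) {
            bits = sign | 0x7f800000u | (mant << 13);          // inf / NaN
        } else {
            bits = sign | ((exp - 15 + 127) << 23) | (mant << 13);
        }
        float f;
        memcpy(&f, &bits, sizeof(f));
        return f;
    }

    // One generator step: LCG update, mask/xor, then sum the two fp16 halves of the word.
    // (The kernels add in fp16; adding in float here is close enough for illustration.)
    static float trellis_next(uint32_t & val) {
        constexpr uint32_t ka    = 89226354;
        constexpr uint32_t kb    = 64248484;
        constexpr uint32_t kmask = 0x8fff8fff;
        constexpr uint32_t km32  = 0x3b603b60;
        val = ka*val + kb;
        const uint32_t s = (val & kmask) ^ km32;
        return fp16_to_float((uint16_t)(s & 0xffff)) + fp16_to_float((uint16_t)(s >> 16));
    }

    int main() {
        // Hypothetical 16-bit code for one group of 8 weights; the kernels seed the
        // generator with ql[ib] + 4096 and then draw 8 consecutive values.
        uint32_t idx = 1234 + 4096;
        for (int j = 0; j < 8; ++j) {
            printf("%g\n", trellis_next(idx));
        }
        return 0;
    }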
* WIP for IQ2_KT * WIP - working basic iq2_kt * still super slow (0.17t/s eval) * flatten 3inst iters + avx2 (0.3t/s eval) * iq3_kt (0.3t/s eval) and renames * wip buggy iq4_KT * fix (0.22t/s eval) * naming and remove unused fn * cleanup * more cleanup * delete unused and noncompiling mmvq functions * Some performance tweaks * Slighty faster iq2_kt * port Trellis struct to iq3_kt, iq4_kt * oops untracked files --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/common.cuh | 7 ++ ggml/src/ggml-cuda/convert.cu | 128 ++++++++++++++++++++ ggml/src/ggml-cuda/dmmv.cu | 264 +++++++++++++++++++++++++++++++++++++++++- ggml/src/ggml-cuda/mmvq.cu | 38 ++++++ ggml/src/ggml-cuda/mmvq.cuh | 1 + 5 files changed, 437 insertions(+), 1 deletion(-) (limited to 'ggml/src/ggml-cuda') diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a04a1929..896ba0df 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -564,6 +564,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI4_XS; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 5afe8c74..17604f1c 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -333,6 +333,101 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } +inline __device__ int nearest_int(float fval) { + assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +float __device__ __forceinline__ trellis_next(uint32_t& val) { + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + uint32_t s; + const half * h = (const half *)&s; + val = ka*val + kb; + s = (val & kmask) ^ km32; + return (float)(h[0]+h[1]); +} + +template +static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + float scale = *(const float *)cx; + const block_iq2_kt * x = (const block_iq2_kt *)(cx + sizeof(float)); + const int64_t i = ii - (row*n_per_row)/QK_K; + + const int64_t tid = threadIdx.x; + const int64_t ib = tid; // 0...31 + dst_t * y = yy + ii*QK_K + 8*ib; + const uint16_t * ql = (const uint16_t *)x[i].ql; + uint32_t idx = ql[ib] + 4096; + const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f; + for (int j = 0; j < 8; ++j) { + y[j] = dl * trellis_next(idx); + } +} + +template +static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + float scale = *(const float *)cx; + const block_iq3_kt * x = (const block_iq3_kt *)(cx + sizeof(float)); + const int64_t i = ii - (row*n_per_row)/QK_K; + + const int64_t tid = threadIdx.x; + const int64_t ib = tid; // 0...31 + dst_t * y = yy + ii*QK_K + 8*ib; + const 
uint16_t * ql = (const uint16_t *)x[i].ql; + uint32_t idx = ql[ib] + 4096; + const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f; + uint8_t mask = 1 << (ib/4); + for (int j = 0; j < 8; ++j) { + y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f); + } +} + +template +static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const float * dptr = (const float *)((const char *)vx + row * row_size); + float scale = dptr[0] * 31.75f * 1.01f; + float row_av = dptr[1]; + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + const int64_t i = ii - (row*n_per_row)/QK_K; + + constexpr int kNumGroups = 64; + + const int64_t tid = threadIdx.x; + const int64_t ib = tid; // 0...31 + dst_t * y = yy + ii*QK_K + 8*ib; + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); //Q::kNblock; + const uint8_t * qh = ql + kNumGroups; + const int ib32 = ib/4; + const int ig = ib%4; + const int jj = ib32*8 + 2*ig; + uint32_t offset = shb[ib32] & 1 ? 4096 + 32768 : 4096; + uint32_t idx1 = ql[jj+0] + ((qh[(jj+0)%(kNumGroups/2)] << (8 - 4*((jj+0)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+0)) & 7) << 12) + offset; + uint32_t idx2 = ql[jj+1] + ((qh[(jj+1)%(kNumGroups/2)] << (8 - 4*((jj+1)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+3)) & 7) << 12) + offset; + int ls = ((shb[ib32] & 0xff) >> 1) - 64; + const float dl = scale * ls; + for (int j = 0; j < 4; ++j) { + y[j+0] = dl * trellis_next(idx1) + row_av; + y[j+4] = dl * trellis_next(idx2) + row_av; + } +} + template static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -968,6 +1063,27 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_ dequantize_block_iq2_xxs<<>>(vx, y); } +template +static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int nb = k / QK_K; + dequantize_block_iq2_kt<<>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row)); +} + +template +static void dequantize_row_iq3_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int nb = k / QK_K; + dequantize_block_iq3_kt<<>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row)); +} + +template +static void dequantize_row_iq4_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int nb = k / QK_K; + dequantize_block_iq4_kt<<>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row)); +} + template static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; @@ -1230,6 +1346,12 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_q6_K_cuda; case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda; + case GGML_TYPE_IQ2_KT: + return dequantize_row_iq2_kt_cuda; + case GGML_TYPE_IQ3_KT: + return dequantize_row_iq3_kt_cuda; + case GGML_TYPE_IQ4_KT: + return dequantize_row_iq4_kt_cuda; case GGML_TYPE_IQ2_XS: return dequantize_row_iq2_xs_cuda; case GGML_TYPE_IQ2_S: @@ 
-1303,6 +1425,12 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_q6_K_cuda; case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda; + case GGML_TYPE_IQ2_KT: + return dequantize_row_iq2_kt_cuda; + case GGML_TYPE_IQ3_KT: + return dequantize_row_iq3_kt_cuda; + case GGML_TYPE_IQ4_KT: + return dequantize_row_iq4_kt_cuda; case GGML_TYPE_IQ2_XS: return dequantize_row_iq2_xs_cuda; case GGML_TYPE_IQ2_S: diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu index 12738240..50e6458d 100644 --- a/ggml/src/ggml-cuda/dmmv.cu +++ b/ggml/src/ggml-cuda/dmmv.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "dmmv.cuh" #include "dequantize.cuh" #include "convert.cuh" @@ -8,6 +15,220 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); #endif +static __device__ __forceinline__ uint32_t trellis_next(uint32_t& val) { + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + val = ka*val + kb; + return (val & kmask) ^ km32; +} + +static __device__ __forceinline__ void trellis_accum(uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) { + const half * h = (const half *)s; + s[0] = trellis_next(val1); + s[1] = trellis_next(val1); + s[2] = trellis_next(val2); + s[3] = trellis_next(val2); +#ifdef GGML_CUDA_F16 + bdot1 = __hfma2(y[ 0], {h[0]+h[1], h[2]+h[3]}, bdot1); + bdot2 = __hfma2(y[64], {h[4]+h[5], h[6]+h[7]}, bdot2); +#else + bdot1.x += y[ 0].x * (float)(h[0] + h[1]); + bdot1.y += y[ 0].y * (float)(h[2] + h[3]); + bdot2.x += y[64].x * (float)(h[4] + h[5]); + bdot2.y += y[64].y * (float)(h[6] + h[7]); +#endif +} + +static __device__ __forceinline__ void trellis_accum_abs(uint8_t signs1, uint8_t signs2, uint8_t mask1, uint8_t mask2, + uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) { + const half * h = (const half *)s; + s[0] = trellis_next(val1); + s[1] = trellis_next(val1); + s[2] = trellis_next(val2); + s[3] = trellis_next(val2); +#ifdef GGML_CUDA_F16 + half h00 = __habs(h[0]+h[1]), h01 = __habs(h[2]+h[3]); + half h10 = __habs(h[4]+h[5]), h11 = __habs(h[6]+h[7]); + half2 h1 = {signs1 & mask1 ? -h00 : h00, signs2 & mask1 ? -h01 : h01}; + half2 h2 = {signs1 & mask2 ? -h10 : h10, signs2 & mask2 ? -h11 : h11}; + bdot1 = __hfma2(y[ 0], h1, bdot1); + bdot2 = __hfma2(y[64], h2, bdot2); +#else + bdot1.x += y[ 0].x * fabsf((float)(h[0] + h[1])) * (signs1 & mask1 ? -1 : 1); + bdot1.y += y[ 0].y * fabsf((float)(h[2] + h[3])) * (signs2 & mask1 ? -1 : 1); + bdot2.x += y[64].x * fabsf((float)(h[4] + h[5])) * (signs1 & mask2 ? -1 : 1); + bdot2.y += y[64].y * fabsf((float)(h[6] + h[7])) * (signs2 & mask2 ? 
-1 : 1); +#endif +} + +static __device__ __forceinline__ void trellis_accum(const dfloat2& dl1, const dfloat2& dl2, const dfloat2& bdot1, const dfloat2& bdot2, dfloat2& tmp) { +#ifdef GGML_CUDA_F16 + tmp = __hfma2(dl1, bdot1, tmp); + tmp = __hfma2(dl2, bdot2, tmp); +#else + tmp.x += dl1.x * bdot1.x + dl2.x * bdot2.x; + tmp.y += dl1.y * bdot1.y + dl2.y * bdot2.y; +#endif +} + +static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst, + const int ncols, int nrows, int64_t row_size) { + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const float * dptr = (const float *)((const char *)vx + row*row_size); + const float d = *dptr * 31.75f * 1.05f; + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + const int num_blocks_per_row = ncols / QK_K; + + dfloat2 tmp = {}; + + const int it = threadIdx.x/2; + const int ix = threadIdx.x%2; + + uint32_t s[4]; + + for (int i = ix; i < num_blocks_per_row; i += 2) { + const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it); + const uint16_t * ql = (const uint16_t *)x[i].ql; + const dfloat scale1 = iq4k_values[x[i].scales[it/4] & 0xf]; + const dfloat scale2 = iq4k_values[x[i].scales[it/4] >> 4]; + const dfloat2 dl1 = {scale1, scale1}; + const dfloat2 dl2 = {scale2, scale2}; + dfloat2 bdot1 = {0, 0}; + dfloat2 bdot2 = {0, 0}; + uint32_t val1 = ql[it+ 0] + 4096; + uint32_t val2 = ql[it+16] + 4096; + for (int k = 0; k < 4; ++k) { + trellis_accum(val1, val2, s, y+k, bdot1, bdot2); + } + trellis_accum(dl1, dl2, bdot1, bdot2, tmp); + } + + // sum up partial sums and write back result + tmp = warp_reduce_sum(tmp); + + if (threadIdx.x == 0) { + dst[row] = d * (float)(tmp.x + tmp.y); + } +} + +static __global__ void dequantize_mul_mat_vec_iq3_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst, + const int ncols, int nrows, int64_t row_size) { + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const float * dptr = (const float *)((const char *)vx + row*row_size); + const float d = *dptr * 31.75f * 1.015f; + const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1); + + const int num_blocks_per_row = ncols / QK_K; + + dfloat2 tmp = {}; + + const int it = threadIdx.x/2; + const int ix = threadIdx.x%2; + + uint32_t s[4]; + + uint8_t mask1 = 1 << (it/4); + uint8_t mask2 = mask1 << 4; + + for (int i = ix; i < num_blocks_per_row; i += 2) { + const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it); + const uint16_t * ql = (const uint16_t *)x[i].ql; + const uint8_t * qh = x[i].qh; + const dfloat scale1 = (x[i].scales[it/4] & 0xf); + const dfloat scale2 = (x[i].scales[it/4] >> 4); + const dfloat2 dl1 = {scale1, scale1}; + const dfloat2 dl2 = {scale2, scale2}; + dfloat2 bdot1 = {0, 0}; + dfloat2 bdot2 = {0, 0}; + uint32_t val1 = ql[it+ 0] + 4096; + uint32_t val2 = ql[it+16] + 4096; + for (int k = 0; k < 4; ++k) { + trellis_accum_abs(qh[(8*it+2*k+0)%32], qh[(8*it+2*k+1)%32], mask1, mask2, val1, val2, s, y+k, bdot1, bdot2); + } + trellis_accum(dl1, dl2, bdot1, bdot2, tmp); + } + + // sum up partial sums and write back result + tmp = warp_reduce_sum(tmp); + + if (threadIdx.x == 0) { + dst[row] = d * (float)(tmp.x + tmp.y); + } +} + +static __global__ void dequantize_mul_mat_vec_iq4_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst, + const int ncols, int nrows, int64_t row_size) { + + constexpr int kNumGroups = 64; + + const int 
row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const float * dptr = (const float *)((const char *)vx + row*row_size); + const float d = dptr[0] * 31.75f * 1.01f; + const float row_av = dptr[1]; + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + + const int num_blocks_per_row = ncols / QK_K; + + dfloat2 tmp1 = {}; + dfloat2 tmp2 = {}; + + const int it = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; // 0 or 1 + + uint32_t s[4]; + + for (int i = ix; i < num_blocks_per_row; i += 2) { + const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it); + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + const uint32_t offset1 = 4096 + ((shb[it/4+0] & 1) << 15); + const uint32_t offset2 = 4096 + ((shb[it/4+4] & 1) << 15); + const dfloat scale1 = (int)((shb[it/4+0] & 0xff) >> 1) - 64; + const dfloat scale2 = (int)((shb[it/4+4] & 0xff) >> 1) - 64; + const dfloat2 dl1 = {scale1, scale1}; + const dfloat2 dl2 = {scale2, scale2}; + const uint32_t sh1 = shb[it/4+0] >> (8 + 6*(it%4)); + const uint32_t sh2 = shb[it/4+4] >> (8 + 6*(it%4)); + dfloat2 bdot1 = {0, 0}; + dfloat2 bdot2 = {0, 0}; + uint32_t val1 = ql[2*it+ 0] + ((qh[2*it+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; + uint32_t val2 = ql[2*it+32] + ((qh[2*it+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; + uint32_t val3 = ql[2*it+ 1] + ((qh[2*it+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; + uint32_t val4 = ql[2*it+33] + ((qh[2*it+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; + for (int k = 0; k < 2; ++k) { + trellis_accum(val1, val2, s, y+k+0, bdot1, bdot2); + trellis_accum(val3, val4, s, y+k+2, bdot1, bdot2); +#ifdef GGML_CUDA_F16 + tmp2 += y[k] + y[k+2] + y[k+64] + y[k+66]; +#else + tmp2.x += y[k].x + y[k+2].x + y[k+64].x + y[k+66].x; + tmp2.y += y[k].y + y[k+2].y + y[k+64].y + y[k+66].y; +#endif + } + trellis_accum(dl1, dl2, bdot1, bdot2, tmp1); + } + + // sum up partial sums and write back result + float tmp = d * (float)(tmp1.x + tmp1.y) + row_av * (float)(tmp2.x + tmp2.y); + tmp = warp_reduce_sum(tmp); + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); @@ -554,6 +775,36 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); } +static void dequantize_mul_mat_vec_iq2_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + constexpr int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(32, ny, 1); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_KT, ncols); + dequantize_mul_mat_vec_iq2_kt<<>>(vx, y, dst, ncols, nrows, row_size); +} + +static void dequantize_mul_mat_vec_iq3_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + constexpr int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(32, ny, 1); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ3_KT, ncols); + dequantize_mul_mat_vec_iq3_kt<<>>(vx, y, dst, ncols, 
nrows, row_size); +} + +static void dequantize_mul_mat_vec_iq4_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + constexpr int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(32, ny, 1); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KT, ncols); + dequantize_mul_mat_vec_iq4_kt<<>>(vx, y, dst, ncols, nrows, row_size); +} + static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; @@ -615,7 +866,8 @@ void ggml_cuda_op_dequantize_mul_mat_vec( bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 || + src0->type == GGML_TYPE_IQ2_KT || src0->type == GGML_TYPE_IQ3_KT || src0->type == GGML_TYPE_IQ4_KT; if (src1_convert_f16) { src1_dfloat = src1_dfloat_a.alloc(ne00); @@ -646,6 +898,15 @@ void ggml_cuda_op_dequantize_mul_mat_vec( case GGML_TYPE_Q2_K: dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; + case GGML_TYPE_IQ2_KT: + dequantize_mul_mat_vec_iq2_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_KT: + dequantize_mul_mat_vec_iq3_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_KT: + dequantize_mul_mat_vec_iq4_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q3_K: dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; @@ -679,5 +940,6 @@ bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) { src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K || src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K || + src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || src0_type == GGML_TYPE_IQ4_KT || src0_type == GGML_TYPE_F16; } diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 14fe2547..d0477835 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -618,3 +618,41 @@ void ggml_cuda_op_mul_mat_vec_q_id( GGML_UNUSED(src1_ddf_i); } + +bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ3_S: + return true; + default: + return false; + } +} diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 525c6bc0..d17765f1 100644 --- 
a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -14,6 +14,7 @@ void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); +bool ggml_cuda_mmvq_type_supported(ggml_type src0_type); void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, -- cgit v1.2.3