author     Kawrakow <iwankawrakow@gmail.com>          2024-10-25 13:08:43 +0200
committer  GitHub <noreply@github.com>                2024-10-25 13:08:43 +0200
commit     6b968f38946117552ffed300771c44ba9b39d3e4 (patch)
tree       dc6b0df69f31ea77d9941d6798a4ef411c688080 /ggml/src/ggml-cuda
parent     9114078959b404899fd67e1af45f0dcbee51b47f (diff)
Bitnet changes (#106)
* Adapting iq2_bn to work without separate scale tensors
Why? It is becoming burdensome to maintain the special Bitnet
conversion in convert_hf_to_gguf.py, so I think it is better
to make iq1_bn and iq2_bn just work with the mainline
conversion script (which does not generate scales).
* Adapting iq1_bn to work without separate scale tensors
* Adapting iq2_bn: CUDA dequantize
* Adapting iq2_bn: CUDA works
* Adapting iq1_bn: CUDA works
* Adapting iq1_bn, iq2_bn: NEON
* Adapting iq1_bn, iq2_bn: Metal
Dequantize works, but there is still something wrong
with the dot products.
* WIP
I absolutely don't see what is wrong with the iq1_bn and iq2_bn
vector dot product kernels.
* Remove iq1_tn and iq2_tn - Part 1
Now that iq1_bn and iq2_bn have per row scales, there is no
reason to also have iq1_tn and iq2_tn.
* Remove iq1_tn and iq2_tn - Part 2
* Bitnet: use the standard llm_build_kv to build self attention
My main motivation was to enable FA. But FA does not work anyway
because head size is 100 for the Bitnet ternary models
(and I had forgotten this little detail).
* Revert "Avoid rebuild of GGML graph for each token (#98)"
This reverts commit f2d315b46f7aacc7df4b86bd8acba387b30e11ca.
As far as I can tell, the commit breaks Metal TG.
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
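
A note on the layout change this commit makes: the scale that used to live in a separate tensor now sits at the start of each quantized row, followed by the packed blocks. A minimal host-side sketch of that layout, assuming QK_IQ1BN = 64 and a 16-byte iq2_bn block (constants and helper names here are illustrative, not ggml API):

    #include <stdint.h>
    #include <string.h>

    #define QK_IQ1BN           64   /* quants per block (assumed) */
    #define BLOCK_IQ2_BN_BYTES 16   /* 64 x 2-bit quants, 4 per byte (assumed) */

    /* An iq2_bn row is [float d][blocks...]; iq1_bn stores a half instead. */
    static size_t iq2_bn_row_size(int64_t n_per_row) {
        return sizeof(float) + (n_per_row / QK_IQ1BN) * BLOCK_IQ2_BN_BYTES;
    }

    /* Read row r's scale, mirroring the pointer arithmetic in the CUDA kernels
       below; for iq1_bn the kernels memcpy a half because the address may not
       be 2-byte aligned. */
    static float iq2_bn_row_scale(const void * vx, int64_t r, size_t row_size) {
        float d;
        memcpy(&d, (const char *)vx + r * row_size, sizeof(d));
        return d;
    }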
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--  ggml/src/ggml-cuda/binbcast.cu  |   2
-rw-r--r--  ggml/src/ggml-cuda/common.cuh   |  14
-rw-r--r--  ggml/src/ggml-cuda/convert.cu   | 147
-rw-r--r--  ggml/src/ggml-cuda/fattn.cu     |   5
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cu  |  84
-rw-r--r--  ggml/src/ggml-cuda/iqk_mmvq.cuh |  11
-rw-r--r--  ggml/src/ggml-cuda/mmvq.cu      |  22
-rw-r--r--  ggml/src/ggml-cuda/vecdotq.cuh  |  89
8 files changed, 96 insertions, 278 deletions
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index 62d115f1..5abbd43c 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -288,7 +288,7 @@ static void scale_f32_cuda_l(const float * x, float * dst, const void * data, co
     scale_f32_l<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, data, k);
 }
 
-void ggml_cuda_op_scale_tensor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+static void ggml_cuda_op_scale_tensor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
     float * dst_d = (float *)dst->data;
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index a5658a24..2eba527f 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -474,13 +474,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_BN> {
 };
 
 template<>
-struct ggml_cuda_type_traits<GGML_TYPE_IQ1_TN> {
-    static constexpr int qk = QK_IQ1BN;
-    static constexpr int qr = QR1_BN;
-    static constexpr int qi = QI1_BN;
-};
-
-template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
     static constexpr int qk = QK_IQ1BN;
     static constexpr int qr = QR1_BN;
@@ -488,13 +481,6 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
 };
 
 template<>
-struct ggml_cuda_type_traits<GGML_TYPE_IQ2_TN> {
-    static constexpr int qk = QK_K;
-    static constexpr int qr = QR2_K;
-    static constexpr int qi = QI2_K;
-};
-
-template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
     static constexpr int qk = QK4_NL;
     static constexpr int qr = QR4_NL;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index e9d15b5d..b9baee1b 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -184,30 +184,6 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
 }
 
 template<typename dst_t>
-static __global__ void dequantize_block_iq2_tn(const void * __restrict__ vx, dst_t * __restrict__ yy,
-        int64_t n_per_row, int64_t row_size) {
-
-    int64_t ii  = blockIdx.x;
-    int64_t row = (QK_K * ii) / n_per_row;
-    const char * cx = (const char *)vx + row * row_size;
-    float d = *(const float *)cx;
-    const block_iq2_tn * x = (const block_iq2_tn *)(cx + sizeof(float));
-    int64_t i = ii - (row*n_per_row)/QK_K;
-
-    const int64_t tid = threadIdx.x;
-    const int64_t n   = tid/32;
-    const int64_t l   = tid - 32*n;
-
-    const uint8_t q = x[i].qs[32*n + l];
-    dst_t * y = yy + ii*QK_K + 128*n;
-
-    y[l+ 0] = d * ((q >> 0) & 3) - d;
-    y[l+32] = d * ((q >> 2) & 3) - d;
-    y[l+64] = d * ((q >> 4) & 3) - d;
-    y[l+96] = d * ((q >> 6) & 3) - d;
-}
-
-template<typename dst_t>
 static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
     const int64_t i = blockIdx.x;
@@ -481,101 +457,72 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
 }
 
 template<typename dst_t>
-static __global__ void dequantize_block_iq1_tn(const void * __restrict__ vx, dst_t * __restrict__ yy,
-        int64_t n_per_row, int64_t row_size) {
-
-    int64_t ii  = blockIdx.x;
-    int64_t row = (QK_K * ii) / n_per_row;
-    const char * cx = (const char *)vx + row * row_size;
-    float scale = *(const half *)cx;
-    const block_iq1_bn * x = (const block_iq1_bn *)(cx + sizeof(half));
-
-    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
-
-//#define COMPUTE_VS(v) 3*v >> 8
-#define COMPUTE_VS(v) (v + (v >> 1)) >> 7
+static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst_t * __restrict__ yy,
+        int64_t n_per_row, int64_t row_size, int64_t nrows) {
 
+    int64_t ii = 256*blockIdx.x;
     const int tid = threadIdx.x;
     const int il = tid/4; // 0...7
     const int ib = tid%4; // 0...3
-    dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
-    const int i16 = il/2;
-    int64_t i = QK_K/QK_IQ1BN * (ii - (row*n_per_row)/QK_K) + ib;
-    uint8_t q = x[i].ql[3*i16+2*(il%2)];
-    for (int j = 0; j < 5; ++j) {
-        uint8_t v = k_mult[j]*q;
-        int8_t vs = COMPUTE_VS(v);
-        y[2*(il%2)+j] = scale*(vs - 1);
-    }
-    q = x[i].ql[3*i16+1];
-    for (int j = 0; j < 2; ++j) {
-        uint8_t v = k_mult[3*(il%2)+j]*q;
-        int8_t vs = COMPUTE_VS(v);
-        y[5*(1-(il%2))+j] = scale*(vs-1);
-    }
-    uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
-    int8_t vs = COMPUTE_VS(v);
-    y[7] = scale*(vs - 1);
+    dst_t * y = yy + ii + 64*ib + 8*il;
 
-#undef COMPUTE_VS
-}
-
-template<typename dst_t>
-static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) {
-
-    const int64_t ii = blockIdx.x;
-    const block_iq1_bn * x = (const block_iq1_bn *) vx;
+    int64_t row = ii / n_per_row;
+    if (row >= nrows) return;
+    const char * cx = (const char *)vx + row * row_size;
+    half d16; memcpy(&d16, cx, sizeof(d16)); // in case not 2-byte aligned
+    float d = d16;
+    const block_iq1_bn * x = (const block_iq1_bn *)(cx + sizeof(d16));
+    ii -= row*n_per_row;
+    int64_t i = ii/QK_IQ1BN + ib;
 
     static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
 
 //#define COMPUTE_VS(v) 3*v >> 8
 #define COMPUTE_VS(v) (v + (v >> 1)) >> 7
 
-    const int tid = threadIdx.x;
-    const int il = tid/4; // 0...7
-    const int ib = tid%4; // 0...3
-    dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
-    int64_t i = QK_K/QK_IQ1BN * ii + ib;
-    if (i >= nb64) return;
     const int i16 = il/2;
     uint8_t q = x[i].ql[3*i16+2*(il%2)];
     for (int j = 0; j < 5; ++j) {
         uint8_t v = k_mult[j]*q;
         int8_t vs = COMPUTE_VS(v);
-        y[2*(il%2)+j] = vs - 1;
+        y[2*(il%2)+j] = d*(vs - 1);
     }
     q = x[i].ql[3*i16+1];
     for (int j = 0; j < 2; ++j) {
         uint8_t v = k_mult[3*(il%2)+j]*q;
         int8_t vs = COMPUTE_VS(v);
-        y[5*(1-(il%2))+j] = vs-1;
+        y[5*(1-(il%2))+j] = d*(vs-1);
     }
     uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
     int8_t vs = COMPUTE_VS(v);
-    y[7] = vs - 1;
+    y[7] = d*(vs - 1);
 
 #undef COMPUTE_VS
 }
 
 template<typename dst_t>
-static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) {
-
-    const int64_t ii = blockIdx.x;
-    const block_iq2_bn * x = (const block_iq2_bn *) vx;
+static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size, int64_t nrows) {
 
+    int64_t ii = 256*blockIdx.x;
     const int64_t tid = threadIdx.x;
     int64_t ib64 = tid%4; // 0...3
     int64_t il = tid/4; // 0...7
-    dst_t * y = yy + 256*ii + 64*ib64 + 2*il;
-    int64_t i = 256/QK_IQ1BN * ii + ib64;
-    if (i >= nb64) return;
-    const float m = -1;
+    dst_t * y = yy + ii + 64*ib64 + 2*il;
+
+    int64_t row = ii / n_per_row;
+    if (row >= nrows) return;
+    const char * cx = (const char *)vx + row * row_size;
+    float d = *(const float *)cx;
+    const block_iq2_bn * x = (const block_iq2_bn *)(cx + sizeof(float));
+    ii -= row*n_per_row;
+    int64_t i = ii/QK_IQ1BN + ib64;
+    const float m = -d;
     auto qs = x[i].qs + 2*il;
     for (int j = 0; j < 2; ++j) {
-        y[j+ 0] = ((qs[j] >> 0) & 3) + m;
-        y[j+16] = ((qs[j] >> 2) & 3) + m;
-        y[j+32] = ((qs[j] >> 4) & 3) + m;
-        y[j+48] = ((qs[j] >> 6) & 3) + m;
+        y[j+ 0] = d * ((qs[j] >> 0) & 3) + m;
+        y[j+16] = d * ((qs[j] >> 2) & 3) + m;
+        y[j+32] = d * ((qs[j] >> 4) & 3) + m;
+        y[j+48] = d * ((qs[j] >> 6) & 3) + m;
     }
 }
@@ -857,14 +804,6 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t n
 }
 
 template<typename dst_t>
-static void dequantize_row_iq2_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
-    const int64_t k = nrows * n_per_row;
-    const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_TN, n_per_row);
-    const int nb = (k + 255) / 256;
-    dequantize_block_iq2_tn<<<nb, 64, 0, stream>>>(vx, y, n_per_row, row_size);
-}
-
-template<typename dst_t>
 static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
     const int64_t k = nrows * n_per_row;
     const int nb = k / QK_K;
@@ -975,25 +914,17 @@ static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t
 
 template<typename dst_t>
 static void dequantize_row_iq1_bn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
     const int64_t k = nrows * n_per_row;
-    const int nb64 = k / QK_IQ1BN;
-    const int nb = (k + 255) / 256;
-    dequantize_block_iq1_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
-}
-
-template<typename dst_t>
-static void dequantize_row_iq1_tn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
-    const int64_t k = nrows * n_per_row;
-    const int64_t row_size = ggml_row_size(GGML_TYPE_IQ1_TN, n_per_row);
+    const int64_t row_size = ggml_row_size(GGML_TYPE_IQ1_BN, n_per_row);
     const int nb = (k + 255) / 256;
-    dequantize_block_iq1_tn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
+    dequantize_block_iq1_bn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size, nrows);
 }
 
 template<typename dst_t>
 static void dequantize_row_iq2_bn_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
     const int64_t k = nrows * n_per_row;
-    const int nb64 = k / QK_IQ1BN;
+    const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_BN, n_per_row);
     const int nb = (k + 255) / 256;
-    dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
+    dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size, nrows);
 }
 
 template<typename dst_t>
@@ -1157,8 +1088,6 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_Q2_K:
             return dequantize_row_q2_K_cuda;
-        case GGML_TYPE_IQ2_TN:
-            return dequantize_row_iq2_tn_cuda;
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_cuda;
         case GGML_TYPE_Q4_K:
@@ -1181,8 +1110,6 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq1_m_cuda;
         case GGML_TYPE_IQ1_BN:
             return dequantize_row_iq1_bn_cuda;
-        case GGML_TYPE_IQ1_TN:
-            return dequantize_row_iq1_tn_cuda;
         case GGML_TYPE_IQ2_BN:
             return dequantize_row_iq2_bn_cuda;
         case GGML_TYPE_IQ4_NL:
@@ -1232,8 +1159,6 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_Q2_K:
             return dequantize_row_q2_K_cuda;
-        case GGML_TYPE_IQ2_TN:
-            return dequantize_row_iq2_tn_cuda;
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_cuda;
         case GGML_TYPE_Q4_K:
@@ -1256,8 +1181,6 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq1_m_cuda;
         case GGML_TYPE_IQ1_BN:
             return dequantize_row_iq1_bn_cuda;
-        case GGML_TYPE_IQ1_TN:
-            return dequantize_row_iq1_tn_cuda;
         case GGML_TYPE_IQ2_BN:
             return dequantize_row_iq2_bn_cuda;
         case GGML_TYPE_IQ4_NL:
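The iq1_bn kernels above unpack five ternary digits ("trits") from each ql byte: multiplying the byte by k_mult[j] = 3^(4-j) wraps mod 256, and floor(3*v/256), written 8-bit-friendly as (v + (v >> 1)) >> 7, recovers digit j; the weight is then scale*(digit - 1) in {-scale, 0, +scale}. A standalone round-trip check (the ceil-based packing in pack5 is an assumption inferred from the decode formula, not necessarily how ggml's quantizer builds the byte):

    #include <stdint.h>
    #include <stdio.h>

    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};

    /* Assumed encoding: v = sum t[j]*3^j for trits t[j] in {0,1,2},
       stored as q = ceil(256*v/243) so the decode below round-trips. */
    static uint8_t pack5(const int * t) {
        int v = 0;
        for (int j = 4; j >= 0; --j) v = 3*v + t[j];
        return (uint8_t)((256*v + 242) / 243);
    }

    /* Decode trit j exactly as the kernels do: v = k_mult[j]*q mod 256,
       then floor(3*v/256), i.e. (v + (v >> 1)) >> 7 for 8-bit v. */
    static int trit(uint8_t q, int j) {
        uint8_t v = (uint8_t)(k_mult[j] * q);
        return (v + (v >> 1)) >> 7;
    }

    int main(void) {
        const int t[5] = {2, 0, 1, 2, 1};
        uint8_t q = pack5(t);
        for (int j = 0; j < 5; ++j)
            printf("trit %d: stored %d decoded %d\n", j, t[j], trit(q, j));
        return 0; /* all five decoded values match the stored ones */
    }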
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 1dfb24f9..c15d6c81 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -38,6 +38,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
             ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
             break;
         default:
+            fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
             GGML_ABORT("fatal error");
             break;
     }
@@ -63,6 +64,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
 //            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
 //            break;
         default:
+            fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
             GGML_ABORT("fatal error");
             break;
     }
@@ -86,6 +88,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
             ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
             break;
         default:
+            fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
             GGML_ABORT("fatal error");
             break;
     }
@@ -114,6 +117,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
             ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
             break;
         default:
+            fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
             GGML_ABORT("fatal error");
             break;
     }
@@ -141,6 +145,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
             ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
             break;
         default:
+            fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
             GGML_ABORT("fatal error");
             break;
     }
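Before the dot-product changes below, the iq2_bn bit layout in one place: each qs byte carries four 2-bit quants in {0,1,2} at bit pairs 0/2/4/6, spread 16 elements apart within a 64-quant block, and a quant q dequantizes to d*q - d, i.e. {-d, 0, +d}. A scalar reference mirroring dequantize_block_iq2_bn above (hypothetical helper, for illustration only):

    #include <stdint.h>

    /* Decode one iq2_bn byte; the four quants land 16 outputs apart,
       matching the y[j+0]/y[j+16]/y[j+32]/y[j+48] stores in the kernel. */
    static void decode_iq2_bn_byte(uint8_t b, float d, float * y) {
        y[ 0] = d * ((b >> 0) & 3) - d;
        y[16] = d * ((b >> 2) & 3) - d;
        y[32] = d * ((b >> 4) & 3) - d;
        y[48] = d * ((b >> 6) & 3) - d;
    }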
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu
index dec54b5e..795243e7 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cu
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -626,35 +626,12 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1(
 
 }
 
-#define VDR_IQ2_TN_Q8_1_MMVQ 1
-#define VDR_IQ2_TN_Q8_1_MMQ 4
-
-static __device__ __forceinline__ float vec_dot_iq2_tn_q8_1(
+__device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    float scale = *(const float *)vbq;
-    const block_iq2_tn * bq2 = (const block_iq2_tn *)((const char *)vbq + sizeof(float)) + kbx;
-
-    const int bq8_offset = QR2_K * (iqs / QI8_1);
-
-    const uint16_t * q16 = (const uint16_t *)bq2->qs + 2*iqs;
-    int v = q16[0] | (q16[1] << 16);
-
-    float sumf = 0;
-    for (int i = 0; i < QR2_K; ++ i) {
-        int u = *((const int *)bq8_1[bq8_offset + i].qs + iqs % QI8_1);
-        float d8 = __low2float(bq8_1[bq8_offset + i].ds);
-        sumf += d8 * (ggml_cuda_dp4a(v & 0x03030303, u, 0) - ggml_cuda_dp4a(0x01010101, u, 0));
-        v >>= 2;
-    }
-    return scale * sumf;
-}
-
-static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
-
-    float scale = *(const half *)vbq;
-    const block_iq1_bn * bq1 = (const block_iq1_bn *)((const char *)vbq + sizeof(half)) + kbx;
+    half d16; memcpy(&d16, vbq, sizeof(d16));
+    float scale = d16;
+    const block_iq1_bn * bq1 = (const block_iq1_bn *)((const char *)vbq + sizeof(d16)) + kbx;
 
     static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
 
@@ -699,7 +676,48 @@ static __device__ __forceinline__ float vec_dot_iq1_tn_q8_1(
         q8++;
     }
 #endif
-    return __low2float(bq8_1[iqs].ds) * scale * sumi;
+    return scale * __low2float(bq8_1[iqs].ds) * sumi;
+}
+
+__device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    float scale = *(const float *)vbq;
+    const block_iq2_bn * bq2 = (const block_iq2_bn *)((const char *)vbq + sizeof(float)) + kbx;
+
+    // iqs is 0 or 1
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    auto qs  = (const uint16_t *)bq2->qs + 4*iqs;
+    auto q8l = (const int *)bq8_1[0].qs + 2*iqs;
+    auto q8h = (const int *)bq8_1[1].qs + 2*iqs;
+    int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+    for (int j = 0; j < 2; ++j) {
+        int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16);
+        int vh = vl >> 4;
+        sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1);
+        sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2);
+        sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3);
+        sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4);
+    }
+    auto d8l = __half22float2(bq8_1[0].ds);
+    auto d8h = __half22float2(bq8_1[1].ds);
+#else
+    int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+    auto q8l = bq8_1[0].qs + 8*iqs;
+    auto q8h = bq8_1[1].qs + 8*iqs;
+    auto qs  = bq2->qs + 8*iqs;
+    for (int j = 0; j < 8; ++j) {
+        sumi1 += q8l[j+ 0] * (qs[j] & 0x03);
+        sumi2 += q8l[j+16] * (qs[j] & 0x0c);
+        sumi3 += q8h[j+ 0] * (qs[j] & 0x30);
+        sumi4 += q8h[j+16] * (qs[j] & 0xc0);
+    }
+    auto d8l = __half22float2(bq8_1[0].ds);
+    auto d8h = __half22float2(bq8_1[1].ds);
+    return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
+#endif
+    return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
 }
 
 } // namespace
@@ -760,16 +778,14 @@ void mul_mat_vec_iq6_k_q8_1_cuda(
     iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
-void mul_mat_vec_iq2_tn_q8_1_cuda(
+void mul_mat_vec_iq1_bn_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN, 1, vec_dot_iq1_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
-void mul_mat_vec_iq1_tn_q8_1_cuda(
+void mul_mat_vec_iq2_bn_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ1_TN, 1, vec_dot_iq1_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN, 1, vec_dot_iq2_bn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh
index 0678c026..1693a73a 100644
--- a/ggml/src/ggml-cuda/iqk_mmvq.cuh
+++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -20,23 +20,22 @@ void mul_mat_vec_iq6_k_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
 
-void mul_mat_vec_iq2_tn_q8_1_cuda(
+void mul_mat_vec_iq4_ks_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
 
-void mul_mat_vec_iq1_tn_q8_1_cuda(
+void mul_mat_vec_iq4_kss_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
 
-void mul_mat_vec_iq4_ks_q8_1_cuda(
+void mul_mat_vec_iq2_ks_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
 
-void mul_mat_vec_iq4_kss_q8_1_cuda(
+void mul_mat_vec_iq1_bn_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
 
-void mul_mat_vec_iq2_ks_q8_1_cuda(
+void mul_mat_vec_iq2_bn_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
-
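The vec_dot kernels above never materialize the -1 offset of the ternary weights: with w = d*(q - 1), q in {0,1,2}, the dot against int8 activations u is sum(w*u) = d*(sum(q*u) - sum(u)), and sum(u) comes precomputed with each q8_1 block (the .y half of ds), which is how the -0.5f*d8l.y and -0.5f*d8h.y terms read. A scalar model of that identity (not the CUDA code itself):

    #include <stdint.h>

    /* Ternary dot product without ever forming q - 1:
       sum(d*(q-1)*u) == d * (sum(q*u) - sum(u)). */
    static float ternary_dot(const uint8_t * q, const int8_t * u, int n, float d) {
        int sum_qu = 0, sum_u = 0;
        for (int i = 0; i < n; ++i) {
            sum_qu += q[i] * u[i]; /* q[i] already unpacked to {0,1,2} */
            sum_u  += u[i];
        }
        return d * (sum_qu - sum_u);
    }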
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 107caf45..cdf13533 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -22,8 +22,6 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
         type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
         type == GGML_TYPE_IQ1_S   ? vec_dot_iq1_s_q8_1 :
         type == GGML_TYPE_IQ1_M   ? vec_dot_iq1_m_q8_1 :
-        type == GGML_TYPE_IQ1_BN  ? vec_dot_iq1_bn_q8_1 :
-        type == GGML_TYPE_IQ2_BN  ? vec_dot_iq2_bn_q8_1 :
         type == GGML_TYPE_IQ4_NL  ? vec_dot_iq4_nl_q8_1 :
         type == GGML_TYPE_IQ4_XS  ? vec_dot_iq4_xs_q8_1 :
         type == GGML_TYPE_IQ3_S   ? vec_dot_iq3_s_q8_1 :
@@ -325,20 +323,6 @@ static void mul_mat_vec_iq1_m_q8_1_cuda(
     mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
-static void mul_mat_vec_iq1_bn_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq2_bn_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
 static void mul_mat_vec_iq4_nl_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -438,12 +422,6 @@ void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_IQ2_BN:
             mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
-        case GGML_TYPE_IQ2_TN:
-            mul_mat_vec_iq2_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ1_TN:
-            mul_mat_vec_iq1_tn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
         case GGML_TYPE_IQ4_NL:
             mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index 7baabb7a..e9af29b9 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -1117,95 +1117,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
     return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
 }
 
-static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
-    const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;
-
-    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
-
-    // iqs is 0 or 1
-
-    int sumi = 0;
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const int * q8 = (const int *)bq8_1[iqs].qs;
-    int val[4];
-    for (int l = 0; l < 2; ++l) {
-        int8_t * a = (int8_t *)val;
-        const int i16 = 2*iqs + l;
-        for (int k = 0; k < 3; ++k) {
-            uint8_t q = bq1->ql[3*i16+k];
-            for (int j = 0; j < 5; ++j) {
-                uint8_t v = k_mult[j]*q;
-                int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
-                *a++ = vs-1;
-            }
-        }
-        uint8_t v = k_mult[i16]*bq1->extra;
-        int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
-        *a++ = vs-1;
-        sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
-    }
-#else
-    const int8_t * q8 = bq8_1[iqs].qs;
-    for (int l = 0; l < 2; ++l) {
-        const int i16 = 2*iqs + l;
-        for (int k = 0; k < 3; ++k) {
-            uint8_t q = bq1->ql[3*i16+k];
-            for (int j = 0; j < 5; ++j) {
-                uint8_t v = k_mult[j]*q;
-                int8_t vs = (v + (v >> 1)) >> 7;
-                sumi += q8[j]*(vs - 1);
-            }
-            q8 += 5;
-        }
-        uint8_t v = k_mult[i16]*bq1->extra;
-        int8_t vs = (v + (v >> 1)) >> 7;
-        sumi += q8[0]*(vs - 1);
-        q8++;
-    }
-#endif
-    return __low2float(bq8_1[iqs].ds) * sumi;
-}
-
-static __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
-    const block_iq2_bn * bq2 = (const block_iq2_bn *) vbq + kbx;
-
-    // iqs is 0 or 1
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    auto qs  = (const uint16_t *)bq2->qs + 4*iqs;
-    auto q8l = (const int *)bq8_1[0].qs + 2*iqs;
-    auto q8h = (const int *)bq8_1[1].qs + 2*iqs;
-    int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
-    for (int j = 0; j < 2; ++j) {
-        int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16);
-        int vh = vl >> 4;
-        sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1);
-        sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2);
-        sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3);
-        sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4);
-    }
-    auto d8l = __half22float2(bq8_1[0].ds);
-    auto d8h = __half22float2(bq8_1[1].ds);
-    return d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y;
-#else
-    int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
-    auto q8l = bq8_1[0].qs + 8*iqs;
-    auto q8h = bq8_1[1].qs + 8*iqs;
-    auto qs  = bq2->qs + 8*iqs;
-    for (int j = 0; j < 8; ++j) {
-        sumi1 += q8l[j+ 0] * (qs[j] & 0x03);
-        sumi2 += q8l[j+16] * (qs[j] & 0x0c);
-        sumi3 += q8h[j+ 0] * (qs[j] & 0x30);
-        sumi4 += q8h[j+16] * (qs[j] & 0xc0);
-    }
-    auto d8l = __half22float2(bq8_1[0].ds);
-    auto d8h = __half22float2(bq8_1[1].ds);
-    return d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y;
-#endif
-}
-
 static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
     const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
     const int8_t * q0_8 = (const int8_t *) &q0_32;