diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-18 13:56:16 +0300 |
---|---|---|
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:52 +0300 |
commit | 01ea9a862d4afb73f936de8f4ef46401ce11b596 (patch) | |
tree | 892dd475ddf47c9ba401859b285b4b517d36135e | |
parent | 2998ca9b14d9b2d4b184cf6d923cea8b07a6320a (diff) |
Bitnet(2.25 bpw): CUDA
We get PP-512 = 9600 t/s, TG-128 = 234 t/s
(but we need to use 8 CPU threads, else results are lower,
so clearly there is something being computed on the CPU).
PP-512 is very close to PP-512(fp16) = 9800 t/s
-rw-r--r-- | ggml-cuda.cu | 2 | ||||
-rw-r--r-- | ggml-cuda/common.cuh | 7 | ||||
-rw-r--r-- | ggml-cuda/convert.cu | 34 | ||||
-rw-r--r-- | ggml-cuda/mmvq.cu | 11 | ||||
-rw-r--r-- | ggml-cuda/vecdotq.cuh | 40 |
5 files changed, 93 insertions, 1 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a3874e20..d51663c8 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2757,7 +2757,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S || a_type == GGML_TYPE_IQ1_M || a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS || - a_type == GGML_TYPE_IQ1_BN) { + a_type == GGML_TYPE_IQ1_BN || a_type == GGML_TYPE_IQ2_BN) { if (b->ne[1] == 1 && ggml_nrows(b) > 1) { return false; } diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index c21513e6..892fd5a6 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -630,6 +630,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_BN> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> { + static constexpr int qk = QK_IQ1BN; + static constexpr int qr = QR1_BN; + static constexpr int qi = QI1_BN; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> { static constexpr int qk = QK4_NL; static constexpr int qr = QR4_NL; diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index 293a47f1..ec78549c 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -444,6 +444,29 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst } } +template<typename dst_t> +static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) { + + const int64_t ii = blockIdx.x; + const block_iq2_bn * x = (const block_iq2_bn *) vx; + + const int64_t tid = threadIdx.x; + int64_t ib64 = tid%4; // 0...3 + int64_t il = tid/4; // 0...7 + dst_t * y = yy + 256*ii + 64*ib64 + 2*il; + int64_t i = 256/QK_IQ1BN * ii + ib64; + if (i >= nb64) return; + const float d = x[i].d; + const float m = -d; + auto qs = x[i].qs + 2*il; + for (int j = 0; j < 2; ++j) { + y[j+ 0] = d * ((qs[j] >> 0) & 3) + m; + y[j+16] = d * ((qs[j] >> 2) & 3) + m; + y[j+32] = d * ((qs[j] >> 4) & 3) + m; + y[j+48] = d * ((qs[j] >> 6) & 3) + m; + } +} + template<typename dst_t> static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -597,6 +620,13 @@ static void dequantize_row_iq1_bn_cuda(const void * vx, dst_t * y, const int64_t } template<typename dst_t> +static void dequantize_row_iq2_bn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb64 = k / QK_IQ1BN; + const int nb = (k + 255) / 256; + dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, nb64); +} + +template<typename dst_t> static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y); @@ -660,6 +690,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq1_m_cuda; case GGML_TYPE_IQ1_BN: return dequantize_row_iq1_bn_cuda; + case GGML_TYPE_IQ2_BN: + return dequantize_row_iq2_bn_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: @@ -709,6 +741,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq1_m_cuda; case GGML_TYPE_IQ1_BN: return dequantize_row_iq1_bn_cuda; + case GGML_TYPE_IQ2_BN: + return dequantize_row_iq2_bn_cuda; case GGML_TYPE_IQ4_NL: return dequantize_row_iq4_nl_cuda; case GGML_TYPE_IQ4_XS: diff --git a/ggml-cuda/mmvq.cu b/ggml-cuda/mmvq.cu index ea7b6328..b0778c73 100644 --- a/ggml-cuda/mmvq.cu +++ b/ggml-cuda/mmvq.cu @@ -21,6 +21,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 : type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 : type == GGML_TYPE_IQ1_BN ? vec_dot_iq1_bn_q8_1 : + type == GGML_TYPE_IQ2_BN ? vec_dot_iq2_bn_q8_1 : type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : @@ -315,6 +316,13 @@ static void mul_mat_vec_iq1_bn_q8_1_cuda( mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +static void mul_mat_vec_iq2_bn_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + static void mul_mat_vec_iq4_nl_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { @@ -408,6 +416,9 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_IQ1_BN: mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ2_BN: + mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ4_NL: mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh index 55dbba3c..acab3865 100644 --- a/ggml-cuda/vecdotq.cuh +++ b/ggml-cuda/vecdotq.cuh @@ -1120,6 +1120,46 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1( return s.f * __low2float(bq8_1[iqs].ds) * sumi; } +// TODO +static __device__ __forceinline__ float vec_dot_iq2_bn_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq2_bn * bq2 = (const block_iq2_bn *) vbq + kbx; + + // iqs is 0 or 1 + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + auto qs = (const uint16_t *)bq2->qs + 4*iqs; + auto q8l = (const int *)bq8_1[0].qs + 2*iqs; + auto q8h = (const int *)bq8_1[1].qs + 2*iqs; + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + for (int j = 0; j < 2; ++j) { + int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16); + int vh = vl >> 4; + sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1); + sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2); + sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3); + sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4); + } + auto d8l = __half22float2(bq8_1[0].ds); + auto d8h = __half22float2(bq8_1[1].ds); + return (float)bq2->d * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y); +#else + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + auto q8l = bq8_1[0].qs + 8*iqs; + auto q8h = bq8_1[1].qs + 8*iqs; + auto qs = bq2->qs + 8*iqs; + for (int j = 0; j < 8; ++j) { + sumi1 += q8l[j+ 0] * (qs[j] & 0x03); + sumi2 += q8l[j+16] * (qs[j] & 0x0c); + sumi3 += q8h[j+ 0] * (qs[j] & 0x30); + sumi4 += q8h[j+16] * (qs[j] & 0xc0); + } + auto d8l = __half22float2(bq8_1[0].ds); + auto d8h = __half22float2(bq8_1[1].ds); + return (float)bq2->d * (d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y); +#endif +} + #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values, int & val1, int & val2) { |