author     Iwan Kawrakow <iwan.kawrakow@gmail.com>   2024-06-16 15:56:32 +0300
committer  Iwan Kawrakow <iwan.kawrakow@gmail.com>   2024-06-22 12:02:51 +0300
commit     0f53bc30bbd4949be80562448419b4bbccd9490a (patch)
tree       caadf89e3776d78b02f7d560e6d961305c4f14df /ggml-cuda
parent     f20b28558bdd20454ce891d36db5f37de819025a (diff)
bitnet: CUDA, scalar, AVX2
Diffstat (limited to 'ggml-cuda')

-rw-r--r--  ggml-cuda/common.cuh   |  7
-rw-r--r--  ggml-cuda/convert.cu   | 37
-rw-r--r--  ggml-cuda/mmvq.cu      | 11
-rw-r--r--  ggml-cuda/vecdotq.cuh  | 46

4 files changed, 101 insertions, 0 deletions
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 5bd24ebe..c21513e6 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -623,6 +623,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
 };
 
 template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_BN> {
+    static constexpr int qk = QK_IQ1BN;
+    static constexpr int qr = QR1_BN;
+    static constexpr int qi = QI1_BN;
+};
+
+template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
     static constexpr int qk = QK4_NL;
     static constexpr int qr = QR4_NL;
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index c0a44470..293a47f1 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -420,6 +420,32 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
 }
 
 template<typename dst_t>
+static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) {
+
+    const int64_t ii = blockIdx.x;
+    const block_iq1_bn * x = (const block_iq1_bn *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    int64_t ib = tid%8;       // 0...7
+    dst_t * y = yy + ii*QK_K + 32*ib + 8*il;
+    int64_t i = QK_K/QK_IQ1BN * ii + ib/(QK_IQ1BN/32);
+    if (i >= nb64) return;
+    ib = ib%(QK_IQ1BN/32);
+    typedef union { float f; uint32_t i; } scale_t;
+    scale_t s;
+    uint8_t u = x[i].extra & 0xff;
+    s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+    const float dl = x[i].extra & (1 << (4*ib + il + 8)) ? -s.f : s.f;
+    uint16_t idx = x[i].ql[4*ib + il] | ((x[i].qh[2*ib + il/2] << (8 - 4*(il%2))) & 0x0f00);
+    const uint16_t gp = iq1bn_grid_u16[idx];
+    for (int j = 0; j < 8; ++j) {
+        y[j] = dl * (((gp >> 2*j) & 3) - 1);
+    }
+}
+
+
+template<typename dst_t>
 static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
     const int64_t i = blockIdx.x;
@@ -564,6 +590,13 @@ static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t
 }
 
 template<typename dst_t>
+static void dequantize_row_iq1_bn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb64 = k / QK_IQ1BN;
+    const int nb = (k + 255) / 256;
+    dequantize_block_iq1_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
+}
+
+template<typename dst_t>
 static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
@@ -625,6 +658,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq1_s_cuda;
         case GGML_TYPE_IQ1_M:
             return dequantize_row_iq1_m_cuda;
+        case GGML_TYPE_IQ1_BN:
+            return dequantize_row_iq1_bn_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
         case GGML_TYPE_IQ4_XS:
@@ -672,6 +707,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq1_s_cuda;
         case GGML_TYPE_IQ1_M:
             return dequantize_row_iq1_m_cuda;
+        case GGML_TYPE_IQ1_BN:
+            return dequantize_row_iq1_bn_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
         case GGML_TYPE_IQ4_XS:
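The one-liner `s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);` in the kernel above assembles an IEEE-754 float directly from the 8-bit scale packed into `extra`: the high nibble of `u` selects a biased exponent between 108 and 123 (magnitudes 2^-19 through 2^-4) and the low nibble becomes the top four bits of the mantissa. A minimal host-side sketch of the same decoding; the helper name and the sample bytes are illustrative only, not part of the commit:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Mirrors the scale decoding used by dequantize_block_iq1_bn and
// vec_dot_iq1_bn_q8_1 above.
static float iq1bn_decode_scale(uint8_t u) {
    // (u >> 4) | 0xf0 maps the high nibble to 240..255; subtracting 132 gives
    // a biased exponent in 108..123, i.e. magnitudes from 2^-19 up to 2^-4.
    // The low nibble lands in the top 4 bits of the 23-bit mantissa field.
    uint32_t bits = ((((u >> 4) | 0xf0u) - 132u) << 23) | ((u & 0x0fu) << 19);
    float f;
    memcpy(&f, &bits, sizeof f);  // bit-cast; same effect as the union in the kernel
    return f;
}

int main() {
    for (int u : {0x00, 0x8f, 0xff}) {  // made-up test bytes
        printf("extra byte 0x%02x -> scale %.9g\n", u, iq1bn_decode_scale((uint8_t)u));
    }
}
```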
diff --git a/ggml-cuda/mmvq.cu b/ggml-cuda/mmvq.cu
index e8d15716..ea7b6328 100644
--- a/ggml-cuda/mmvq.cu
+++ b/ggml-cuda/mmvq.cu
@@ -20,6 +20,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
         type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
         type == GGML_TYPE_IQ1_S   ? vec_dot_iq1_s_q8_1 :
         type == GGML_TYPE_IQ1_M   ? vec_dot_iq1_m_q8_1 :
+        type == GGML_TYPE_IQ1_BN  ? vec_dot_iq1_bn_q8_1 :
         type == GGML_TYPE_IQ4_NL  ? vec_dot_iq4_nl_q8_1 :
         type == GGML_TYPE_IQ4_XS  ? vec_dot_iq4_xs_q8_1 :
         type == GGML_TYPE_IQ3_S   ? vec_dot_iq3_s_q8_1 :
@@ -307,6 +308,13 @@ static void mul_mat_vec_iq1_m_q8_1_cuda(
     mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
+static void mul_mat_vec_iq1_bn_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
 static void mul_mat_vec_iq4_nl_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
@@ -397,6 +405,9 @@ void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_IQ1_M:
             mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
+        case GGML_TYPE_IQ1_BN:
+            mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
         case GGML_TYPE_IQ4_NL:
             mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
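Both the dequantize kernel above and `vec_dot_iq1_bn_q8_1` below (in vecdotq.cuh) rebuild a 12-bit index into the ternary lookup table from one full byte of `ql` plus one nibble of `qh`; two consecutive entries share a `qh` byte, the even entry taking its low nibble and the odd entry its high nibble as bits 8..11. In the `iq1bn_grid_u16` encoding, each 16-bit entry packs eight ternary digits as 2-bit codes read LSB-first, with 0/1/2 standing for -1/0/+1. A self-contained sketch of the index assembly and digit decoding; all input values here are made up for illustration:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Two consecutive grid indices share one qh byte: entry l = 0 takes the
    // low nibble, entry l = 1 the high nibble, as bits 8..11 of the index.
    uint8_t ql[2] = { 0xA7, 0x13 };  // low 8 bits of each index (made-up)
    uint8_t qh    = 0x3C;            // packed high nibbles (made-up)
    for (int l = 0; l < 2; ++l) {
        uint16_t idx = ql[l] | (((uint16_t)qh << (8 - 4*(l%2))) & 0x0f00);
        printf("l=%d -> 12-bit grid index 0x%03x\n", l, idx);
    }

    // One hypothetical grid entry: eight 2-bit codes, decoded LSB-first.
    uint16_t gp = 0x5A96;
    for (int j = 0; j < 8; ++j) {
        printf("%+d ", (int)((gp >> 2*j) & 3) - 1);  // code 0/1/2 -> -1/0/+1
    }
    printf("\n");
}
```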
diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh
index 3b12d656..55dbba3c 100644
--- a/ggml-cuda/vecdotq.cuh
+++ b/ggml-cuda/vecdotq.cuh
@@ -1074,6 +1074,52 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
     return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
 }
 
+static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;
+
+    typedef union { float f; uint32_t i; } scale_t;
+    scale_t s;
+    uint8_t u = bq1->extra & 0xff;
+    s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
+    uint8_t extra = bq1->extra >> (8 + 4*iqs);
+    int sumi = 0;
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const int * q8 = (const int *)bq8_1[iqs].qs;
+    //const int minus = 0xffffffff;
+    for (int l = 0; l < 4; ++l) {
+        int sign = extra & (1 << l) ? -1 : 1;
+        uint16_t val = iq1bn_grid_xxx[bq1->ql[4*iqs + l] | ((bq1->qh[2*iqs + l/2] << (8 - 4*(l%2))) & 0x0f00)];
+        uint8_t vp = val & 0xff, vm = val >> 8;
+        int32_t vp1 = __vcmpeq4(((vp & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
+        int32_t vp2 = __vcmpeq4(((vp >> 4) * 0x01010101) & 0x08040201, 0x08040201);
+        int32_t vm1 = __vcmpeq4(((vm & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
+        int32_t vm2 = __vcmpeq4(((vm >> 4) * 0x01010101) & 0x08040201, 0x08040201);
+        sumi += (__dp4a(q8[2*l+0], vm1, __dp4a(q8[2*l+1], vm2, 0)) - __dp4a(q8[2*l+0], vp1, __dp4a(q8[2*l+1], vp2, 0)))*sign;
+        //int32_t vp1 = __vcmpeq4(((vp & 0xf) * 0x01010101) & 0x08040201, 0x08040201) & q8[2*l+0];
+        //int32_t vp2 = __vcmpeq4(((vp >> 4) * 0x01010101) & 0x08040201, 0x08040201) & q8[2*l+1];
+        //int32_t vm1 = __vcmpeq4(((vm & 0xf) * 0x01010101) & 0x08040201, 0x08040201) & q8[2*l+0];
+        //int32_t vm2 = __vcmpeq4(((vm >> 4) * 0x01010101) & 0x08040201, 0x08040201) & q8[2*l+1];
+        //int32_t v1 = __vsubss4(vp1, vm1);
+        //int32_t v2 = __vsubss4(vp2, vm2);
+        //sumi += __dp4a(v1, 0x01010101, __dp4a(v2, 0x01010101, 0))*sign;
+    }
+#else
+    const int8_t * q8 = bq8_1[iqs].qs;
+    for (int l = 0; l < 4; ++l) {
+        uint16_t val = iq1bn_grid_u16[bq1->ql[4*iqs + l] | ((bq1->qh[2*iqs + l/2] << (8 - 4*(l%2))) & 0x0f00)];
+        int s1 = 0, s2 = 0;
+        for (int j = 0; j < 8; ++j) {
+            s1 += q8[j] * ((val >> 2*j) & 3);
+            s2 += q8[j];
+        }
+        sumi += extra & (1 << l) ? s2 - s1 : s1 - s2;
+        q8 += 8;
+    }
+#endif
+    return s.f * __low2float(bq8_1[iqs].ds) * sumi;
+}
+
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
 static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values, int & val1, int & val2) {
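The DP4A path above never unpacks the ternary digits: in the `iq1bn_grid_xxx` encoding each 16-bit entry holds two 8-bit bitmasks, one marking the +1 positions (`vp`) and one the -1 positions (`vm`). Multiplying a nibble of such a mask by `0x01010101` broadcasts it into four bytes, AND-ing with `0x08040201` isolates one distinct bit per byte, and `__vcmpeq4` turns each set bit into a full `0xff` byte. A host-side emulation of that mask expansion; the `vcmpeq4` helper stands in for the CUDA intrinsic:

```cpp
#include <cstdint>
#include <cstdio>

// Emulates the CUDA __vcmpeq4 intrinsic: byte-wise compare, 0xff where equal.
static uint32_t vcmpeq4(uint32_t a, uint32_t b) {
    uint32_t r = 0;
    for (int i = 0; i < 4; ++i) {
        if (((a >> 8*i) & 0xff) == ((b >> 8*i) & 0xff)) r |= 0xffu << (8*i);
    }
    return r;
}

int main() {
    // Nibble 0b0101: bits 0 and 2 set -> bytes 0 and 2 become 0xff.
    uint8_t nib = 0x5;
    uint32_t mask = vcmpeq4((nib * 0x01010101u) & 0x08040201u, 0x08040201u);
    printf("nibble 0x%x -> byte mask 0x%08x\n", nib, mask);  // prints 0x00ff00ff
}
```

Since `0xff` is -1 as a signed byte, `__dp4a(q8[...], mask, 0)` yields minus the sum of the int8 activations at the marked positions, so subtracting the "+1" accumulation from the "-1" accumulation gives the signed ternary dot product. The scalar fallback reaches the same result arithmetically: with 2-bit codes c_j in {0,1,2}, s1 - s2 = sum_j q8[j]*(c_j - 1), which is exactly the dot product against digits in {-1,0,+1}, negated when the corresponding `extra` sign bit is set.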