diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-10-26 15:13:45 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-26 15:13:45 +0200 |
commit | 3805c84686f40fc4423d45308cab6adac2eafdd4 (patch) | |
tree | d753243884c6d45daf0c2cfd6c0c8e8afb625b00 | |
parent | f7b05a09ddb2b2579f6301a6223d894f5b97c494 (diff) |
Improve Bitnet PP on Metal (#108)
iq1_bn goes from 702 t/s to 716 t/s
iq2_bn goes from 714 t/s to 743 t/s
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-metal.metal | 66 |
1 file changed, 50 insertions, 16 deletions
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 287d8563..e8f742fc 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -7473,11 +7473,11 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & template <typename type4x4> -void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) { +void dequantize_iq1_bn(half d, device const block_iq1_bn * xb, short il, thread type4x4 & reg) { // il is in 0...3 constexpr uint16_t k_mult[5] = {81, 27, 9, 3, 1}; - constexpr half k_values[3] = {-1.h, 0.h, 1.h}; + const half k_values[3] = {-d, 0.h, d}; for (int k = 0; k < 3; ++k) { uint16_t q = xb->ql[3*il + k]; @@ -7496,15 +7496,22 @@ void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 } template <typename type4x4> -void dequantize_iq2_bn(device const block_iq2_bn * xb, short il, thread type4x4 & reg) { +void dequantize_iq2_bn(half d, device const block_iq2_bn * xb, short il, thread type4x4 & reg) { // il is in 0...3 constexpr half k_scale[4] = {1.h, 0.25h, 0.0625h, 0.015625h}; - constexpr uint8_t k_mask[4] = {0x03, 0x0c, 0x30, 0xc0}; - const half d = k_scale[il]; - uint8_t mask = k_mask[il]; + const half db = d * k_scale[il]; + const uint32_t mask = 0x03030303 << 2*il; - for (int j = 0; j < 16; ++j) { - reg[j/4][j%4] = d * (xb->qs[j] & mask) - 1; + device const uint32_t * qs = (device const uint32_t *)xb->qs; + uint32_t aux32; + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; + + for (int j = 0; j < 4; ++j) { + aux32 = qs[j] & mask; + reg[j][0] = db * aux8[0] - d; + reg[j][1] = db * aux8[1] - d; + reg[j][2] = db * aux8[2] - d; + reg[j][3] = db * aux8[3] - d; } } @@ -7855,12 +7862,12 @@ struct DefaultDequantizer { }; template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(device const Block *, short, thread T4x4&), bool may_not_be_aligned = false> -struct DequantizerRS{ +struct DequantizerRS { using type4x4 = T4x4; 
DequantizerRS(device const char * cx, short il = 0) : il(il) { if (may_not_be_aligned) { thread char * aux = (thread char *)&d; - for (int i = 0; i < sizeof(d); ++i) aux[i] = cx[i]; + for (int i = 0; i < int(sizeof(d)); ++i) aux[i] = cx[i]; } else { d = *(device const Scale *)cx; } @@ -7883,6 +7890,33 @@ struct DequantizerRS{ Scale d; }; +template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(half d, device const Block *, short, thread T4x4&), bool may_not_be_aligned = false> +struct DequantizerRSBN { + using type4x4 = T4x4; + DequantizerRSBN(device const char * cx, short il = 0) : il(il) { + if (may_not_be_aligned) { + thread char * aux = (thread char *)&d; + for (int i = 0; i < int(sizeof(d)); ++i) aux[i] = cx[i]; + } else { + d = *(device const Scale *)cx; + } + x = (device const Block *)(cx + sizeof(Scale)); + } + inline void convert(thread T4x4& t) const { + dequantize(d, x, il, t); + } + inline void convert(int64_t ind, thread T4x4& t) { + dequantize(d, x + ind/nl, ind%nl, t); + } + inline void next() { + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? 
x + (2+nl-1)/nl : x; + } + device const Block * x; + short il; + Scale d; +}; + // each block_q contains 16*nl weights template<typename T, typename simdgroup_T8x8, typename Dequantizer> kernel void kernel_mul_mm(device const uchar * src0, @@ -8251,8 +8285,8 @@ template [[host_name("kernel_get_rows_iq3_k")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_iq4_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_k, QK_NL, dequantize_iq4_k>; template [[host_name("kernel_get_rows_iq5_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq5_k, QK_NL, dequantize_iq5_k>; template [[host_name("kernel_get_rows_iq6_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq6_k, QK_NL, dequantize_iq6_k>; -template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>; -template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>; +template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRSBN<float4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>; +template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRSBN<float4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>; template [[host_name("kernel_get_rows_iq4_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>; template [[host_name("kernel_get_rows_iq4_kss")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>; template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>; @@ -8294,8 +8328,8 @@ template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel 
mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_k, QK_NL, dequantize_iq4_k>>; template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq5_k, QK_NL, dequantize_iq5_k>>; template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq6_k, QK_NL, dequantize_iq6_k>>; -template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>; -template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>; +template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>; +template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>; template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>; template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>; template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>; @@ -8334,8 +8368,8 @@ template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_k, QK_NL, dequantize_iq4_k>>; template [[host_name("kernel_mul_mm_id_iq5_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq5_k, QK_NL, dequantize_iq5_k>>; 
template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq6_k, QK_NL, dequantize_iq6_k>>; -template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>; -template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>; +template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRSBN<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>; +template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRSBN<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>; template [[host_name("kernel_mul_mm_id_iq4_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>; template [[host_name("kernel_mul_mm_id_iq4_kss_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>; template [[host_name("kernel_mul_mm_id_iq2_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>; |