diff options
-rw-r--r-- | examples/quantize/quantize.cpp | 2 | ||||
-rw-r--r-- | ggml-metal.metal | 35 | ||||
-rw-r--r-- | llama.cpp | 2 |
3 files changed, 21 insertions, 18 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index a5ffb2b2..0a61a083 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -26,7 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, - { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.75 bpw quantization (Bitnet)", }, + { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", }, { "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, diff --git a/ggml-metal.metal b/ggml-metal.metal index c9439727..34e77728 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -5053,7 +5053,7 @@ void kernel_mul_mv_iq1_bn_f32_impl( device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[8]; + float4 yl[2]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_IQ1BN / 32); @@ -5066,12 +5066,12 @@ void kernel_mul_mv_iq1_bn_f32_impl( uint32_t aux32[2]; thread const uint8_t * aux8 = (thread const uint8_t *)aux32; + const float values[3] = {-1.f, 0.f, 1.f}; + for (int ib32 = ix; ib32 < nb32; ib32 += 8) { - float sumy = 0.f; - for (int i = 0; i < 8; ++i) { - yl[i] = y4[i]; sumy += yl[i]; - } + yl[0] = {y4[0], y4[4], y4[2], y4[6]}; + yl[1] = {y4[1], y4[5], y4[3], y4[7]}; const int ibl = ib32 / (QK_IQ1BN / 32); const int ib = ib32 % (QK_IQ1BN / 32); @@ -5084,15 +5084,15 @@ void kernel_mul_mv_iq1_bn_f32_impl( for (int row = 0; row < N_DST; row++) { uint8_t signs = extra[0] >> (4*ib + ir); - float acc = 0.f; uint32_t v = iq1bn_grid_u16[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)]; uint32_t v32 = v | (v << 12); - aux32[0] = v32 & 0x03030303; aux32[1] = v32 & 0x0c0c0c0c; - acc = yl[0] * aux8[0] + yl[4] * aux8[1] + yl[2]*aux8[2] + yl[6]*aux8[3]; - acc += (yl[1] * aux8[4] + yl[5] * aux8[5] + yl[3]*aux8[6] + yl[7]*aux8[7]) * 0.25f; + aux32[0] = v32 & 0x03030303; aux32[1] = (v32 >> 2) & 0x03030303; + float4 acc4 = yl[0] * float4{values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]} + + yl[1] * float4{values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]}; + float acc = acc4[0] + acc4[1] + acc4[2] + acc4[3]; - sumf[row] += (signs & 1 ? sumy-acc : acc-sumy); + sumf[row] += (signs & 1 ? -acc : acc); extra += nb*sizeof(block_iq1_bn); ql += nb*sizeof(block_iq1_bn); @@ -5956,8 +5956,8 @@ template <typename type4x4> void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) { // il is in 0...3 uint8_t gs = xb->extra >> 2*il; - const float d1 = gs & 1 ? -1 : 1; - const float d2 = gs & 2 ? -1 : 1; + const half d1 = gs & 1 ? -1.h : 1.h; + const half d2 = gs & 2 ? -1.h : 1.h; uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)]; uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)]; @@ -5966,12 +5966,15 @@ void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 uint32_t aux32; thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; + const half values[3] = {-1.h, 0.h, 1.h}; + +#pragma unroll(4) for (int i = 0; i < 4; ++i) { aux32 = (v >> 2*i) & 0x03030303; - reg[0][i] = d1*aux8[0] - d1; - reg[1][i] = d1*aux8[1] - d1; - reg[2][i] = d2*aux8[2] - d2; - reg[3][i] = d2*aux8[3] - d2; + reg[0][i] = d1*values[aux8[0]]; + reg[1][i] = d1*values[aux8[1]]; + reg[2][i] = d2*values[aux8[2]]; + reg[3][i] = d2*values[aux8[3]]; } } @@ -4130,7 +4130,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_M :return "IQ1_M - 1.75 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ1_BN :return "IQ1_BN - 1.75 bpw Bitnet"; + case LLAMA_FTYPE_MOSTLY_IQ1_BN :return "IQ1_BN - 1.625 bpw Bitnet"; case LLAMA_FTYPE_MOSTLY_IQ2_BN :return "IQ2_BN - 2.00 bpw Bitnet"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; |