summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- examples/quantize/quantize.cpp 2
-rw-r--r-- ggml-metal.metal 35
-rw-r--r-- llama.cpp 2
3 files changed, 21 insertions, 18 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index a5ffb2b2..0a61a083 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -26,7 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
- { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.75 bpw quantization (Bitnet)", },
+ { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", },
{ "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
diff --git a/ggml-metal.metal b/ggml-metal.metal
index c9439727..34e77728 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -5053,7 +5053,7 @@ void kernel_mul_mv_iq1_bn_f32_impl(
device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
- float yl[8];
+ float4 yl[2];
float sumf[N_DST]={0.f}, all_sum;
const int nb32 = nb * (QK_IQ1BN / 32);
@@ -5066,12 +5066,12 @@ void kernel_mul_mv_iq1_bn_f32_impl(
uint32_t aux32[2];
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
+ const float values[3] = {-1.f, 0.f, 1.f};
+
for (int ib32 = ix; ib32 < nb32; ib32 += 8) {
- float sumy = 0.f;
- for (int i = 0; i < 8; ++i) {
- yl[i] = y4[i]; sumy += yl[i];
- }
+ yl[0] = {y4[0], y4[4], y4[2], y4[6]};
+ yl[1] = {y4[1], y4[5], y4[3], y4[7]};
const int ibl = ib32 / (QK_IQ1BN / 32);
const int ib = ib32 % (QK_IQ1BN / 32);
@@ -5084,15 +5084,15 @@ void kernel_mul_mv_iq1_bn_f32_impl(
for (int row = 0; row < N_DST; row++) {
uint8_t signs = extra[0] >> (4*ib + ir);
- float acc = 0.f;
uint32_t v = iq1bn_grid_u16[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)];
uint32_t v32 = v | (v << 12);
- aux32[0] = v32 & 0x03030303; aux32[1] = v32 & 0x0c0c0c0c;
- acc = yl[0] * aux8[0] + yl[4] * aux8[1] + yl[2]*aux8[2] + yl[6]*aux8[3];
- acc += (yl[1] * aux8[4] + yl[5] * aux8[5] + yl[3]*aux8[6] + yl[7]*aux8[7]) * 0.25f;
+ aux32[0] = v32 & 0x03030303; aux32[1] = (v32 >> 2) & 0x03030303;
+ float4 acc4 = yl[0] * float4{values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]}
+ + yl[1] * float4{values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]};
+ float acc = acc4[0] + acc4[1] + acc4[2] + acc4[3];
- sumf[row] += (signs & 1 ? sumy-acc : acc-sumy);
+ sumf[row] += (signs & 1 ? -acc : acc);
extra += nb*sizeof(block_iq1_bn);
ql += nb*sizeof(block_iq1_bn);
@@ -5956,8 +5956,8 @@ template <typename type4x4>
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
// il is in 0...3
uint8_t gs = xb->extra >> 2*il;
- const float d1 = gs & 1 ? -1 : 1;
- const float d2 = gs & 2 ? -1 : 1;
+ const half d1 = gs & 1 ? -1.h : 1.h;
+ const half d2 = gs & 2 ? -1.h : 1.h;
uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)];
uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)];
@@ -5966,12 +5966,15 @@ void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4
uint32_t aux32;
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32;
+ const half values[3] = {-1.h, 0.h, 1.h};
+
+#pragma unroll(4)
for (int i = 0; i < 4; ++i) {
aux32 = (v >> 2*i) & 0x03030303;
- reg[0][i] = d1*aux8[0] - d1;
- reg[1][i] = d1*aux8[1] - d1;
- reg[2][i] = d2*aux8[2] - d2;
- reg[3][i] = d2*aux8[3] - d2;
+ reg[0][i] = d1*values[aux8[0]];
+ reg[1][i] = d1*values[aux8[1]];
+ reg[2][i] = d2*values[aux8[2]];
+ reg[3][i] = d2*values[aux8[3]];
}
}
diff --git a/llama.cpp b/llama.cpp
index 8d2be592..f8d6911b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4130,7 +4130,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_M :return "IQ1_M - 1.75 bpw";
- case LLAMA_FTYPE_MOSTLY_IQ1_BN :return "IQ1_BN - 1.75 bpw Bitnet";
+ case LLAMA_FTYPE_MOSTLY_IQ1_BN :return "IQ1_BN - 1.625 bpw Bitnet";
case LLAMA_FTYPE_MOSTLY_IQ2_BN :return "IQ2_BN - 2.00 bpw Bitnet";
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";