summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-18 13:42:42 +0200
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-22 12:02:52 +0300
commitfece7e1db7bf73497a32751af06c6dbf48c26b19 (patch)
tree98a89902822006586a4454e092b679d7e746b212
parent4f51348d3d5c0f0bfee42d0a7efc81030f046d13 (diff)
Bitnet(2.25 bpw): faster Metal dot product
With this we get TG-128 = 97 t/s.
-rw-r--r--ggml-metal.metal36
1 files changed, 19 insertions, 17 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index b9be74b2..e5ef552c 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -5127,16 +5127,16 @@ void kernel_mul_mv_iq2_bn_f32_impl(
device const block_iq2_bn * x = (device const block_iq2_bn *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
- float yl[8];
+ float yl[16];
float sumf[N_DST]={0.f};
float d1bn[N_DST];
const int nb32 = nb * (QK_IQ1BN / 32);
- const int ix = tiisg/8; // 0...3
- const int ir = tiisg%8; // 0...7
+ const int ix = tiisg/4; // 0...7
+ const int ir = tiisg%4; // 0...3
- device const float * y4 = y + 64 * ix + 2 * ir;
+ device const float * y4 = y + 64 * ix + 4 * ir;
typedef union { float f; uint32_t i; } scale_t;
scale_t scale;
@@ -5145,32 +5145,34 @@ void kernel_mul_mv_iq2_bn_f32_impl(
d1bn[row] = x[nb*row].d;
}
- for (int ib = ix; ib < nb; ib += 4) {
+ for (int ib = ix; ib < nb; ib += 8) {
float sumy = 0.f;
- for (int i = 0; i < 2; ++i) {
- yl[i+0] = y4[i+ 0]; sumy += yl[i];
- yl[i+2] = y4[i+16]; sumy += yl[i+2];
- yl[i+4] = y4[i+32]; sumy += yl[i+4];
- yl[i+6] = y4[i+48]; sumy += yl[i+6];
+ for (int i = 0; i < 4; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0];
+ yl[i+ 4] = y4[i+16]; sumy += yl[i+ 4];
+ yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8];
+ yl[i+12] = y4[i+48]; sumy += yl[i+12];
}
- device const uint8_t * qs = x[ib].qs + 2*ir;
+ device const uint8_t * qs = x[ib].qs + 4*ir;
for (int row = 0; row < N_DST; row++) {
- float acc = 0;
- for (int j = 0; j < 2; ++j) {
- acc += yl[j+0] * ((qs[j] >> 0) & 0x03) + yl[j+2] * ((qs[j] >> 2) & 0x03)
- + yl[j+4] * ((qs[j] >> 4) & 0x03) + yl[j+6] * ((qs[j] >> 6) & 0x03);
+ float4 acc = {0.f};
+ for (int j = 0; j < 4; ++j) {
+ acc[0] += yl[j+ 0] * (qs[j] & 0x03);
+ acc[1] += yl[j+ 4] * (qs[j] & 0x0c);
+ acc[2] += yl[j+ 8] * (qs[j] & 0x30);
+ acc[3] += yl[j+12] * (qs[j] & 0xc0);
}
- sumf[row] += acc - sumy;
+ sumf[row] += acc[0] + 0.25f*acc[1] + 0.0625*acc[2] + 0.015625f*acc[3] - sumy;
qs += nb*sizeof(block_iq2_bn);
}
- y4 += 64 * 4;
+ y4 += 64 * 8;
}
for (int row = 0; row < N_DST; row += 2) {