diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-17 11:51:20 +0200 |
---|---|---|
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:51 +0300 |
commit | d42e9e2922a836f837b33f4e5f768c4fa6de22ba (patch) | |
tree | 705641e3cef99b3aef15664cb6ffba9da26b9980 | |
parent | 9d58489c33061c9ba18be045984b2f87fecb837b (diff) |
iq1_bn(Metal): 64 -> 66.2 t/s for TG
This should be good enough. One cannot ask
Apple Silicon to do too much work.
-rw-r--r-- | ggml-metal.metal | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal index 097dd164..7f94e133 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -5046,8 +5046,10 @@ void kernel_mul_mv_iq1_bn_f32_impl( d1bn[row] = scale.f; } - uint16_t aux16; - thread const uint8_t * aux8 = (thread const uint8_t *)&aux16; + //uint32_t aux32; + //thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; + uint32_t aux32[2]; + thread const uint8_t * aux8 = (thread const uint8_t *)aux32; for (int ib32 = ix; ib32 < nb32; ib32 += 8) { @@ -5069,15 +5071,15 @@ void kernel_mul_mv_iq1_bn_f32_impl( uint8_t signs = extra[0] >> (8 + 4*ib + ir); float acc = 0.f; - uint16_t v = iq1bn_grid_u16[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)]; - aux16 = v & 0x0303; - acc += yl[0] * aux8[0] + yl[4] * aux8[1]; - aux16 = (v >> 2) & 0x0303; - acc += yl[1] * aux8[0] + yl[5] * aux8[1]; - aux16 = (v >> 4) & 0x0303; - acc += yl[2] * aux8[0] + yl[6] * aux8[1]; - aux16 = (v >> 6) & 0x0303; - acc += yl[3] * aux8[0] + yl[7] * aux8[1]; + uint32_t v = iq1bn_grid_u16[ql[0] | ((qh[0] << (8 - 4*(ir%2))) & 0x0f00)]; + uint32_t v32 = v | (v << 12); + //aux32 = v32 & 0x03030303; + //acc += yl[0] * aux8[0] + yl[4] * aux8[1] + yl[2]*aux8[2] + yl[6]*aux8[3]; + //aux32 = v32 & 0x0c0c0c0c; + //acc += (yl[1] * aux8[0] + yl[5] * aux8[1] + yl[3]*aux8[2] + yl[7]*aux8[3]) * 0.25f; + aux32[0] = v32 & 0x03030303; aux32[1] = v32 & 0x0c0c0c0c; + acc = yl[0] * aux8[0] + yl[4] * aux8[1] + yl[2]*aux8[2] + yl[6]*aux8[3]; + acc += (yl[1] * aux8[4] + yl[5] * aux8[5] + yl[3]*aux8[6] + yl[7]*aux8[7]) * 0.25f; sumf[row] += (signs & 1 ? sumy-acc : acc-sumy); |