diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-07-31 14:53:54 +0200 |
---|---|---|
committer | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-08-01 09:38:06 +0200 |
commit | b572dd534779541d5da415146f695e6d26684a33 (patch) | |
tree | 65eeabfb6f0fa38a195228a51ea75ad0b7ef32b7 | |
parent | 394ed3913c16ad9baae24e93e126847030063fad (diff) |
iq2/3_k: tiny bit faster Metal dot products
-rw-r--r-- | ggml/src/ggml-metal.metal | 18 |
1 file changed, 14 insertions, 4 deletions
```diff
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 88fe607b..53b2ddb8 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -5246,6 +5246,7 @@ void kernel_mul_mv_iq2_k_f32_impl(
 
     uint32_t aux32[2];
     thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
+    uint16_t shift[4];
 
     for (int ib = ix; ib < nb; ib += 4) {
 
@@ -5266,10 +5267,14 @@ void kernel_mul_mv_iq2_k_f32_impl(
         thread const int8_t * s8 = (thread const int8_t *)&scales32;
         uint16_t extra = xb.extra >> (8*iq + is);
 
+        shift[0] = (extra << 2) & 4;
+        shift[1] = (extra << 1) & 4;
+        shift[2] = (extra >> 0) & 4;
+        shift[3] = (extra >> 1) & 4;
+
         float4 acc = {0.f};
         for (int l = 0; l < 4; ++l) {
-            constant float * values = kvalues_iq2k_f + 4*(extra & 1);
-            extra >>= 2;
+            constant float * values = kvalues_iq2k_f + shift[l];
             aux32[0] = (q32[0] >> 2*l) & 0x03030303;
             aux32[1] = (q32[1] >> 2*l) & 0x03030303;
             for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
@@ -5365,6 +5370,7 @@ void kernel_mul_mv_iq3_k_f32_impl(
     uint32_t vl[2], vh[2];
     uint32_t aux32[2];
     thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
+    uint16_t shift[4];
 
     for (int ib = ix; ib < nb; ib += 4) {
 
@@ -5387,6 +5393,11 @@ void kernel_mul_mv_iq3_k_f32_impl(
         uint16_t extra = xb.extra >> (8*iq + is);
         uint16_t signs = xb.scales_h >> (8*iq + is);
 
+        shift[0] = (extra << 3) & 8;
+        shift[1] = (extra << 2) & 8;
+        shift[2] = (extra << 1) & 8;
+        shift[3] = (extra << 0) & 8;
+
         vl[0] = ql16[0] | ql16[1] << 16;
         vl[1] = ql16[2] | ql16[3] << 16;
         vh[0] = ((qh16[0] | (qh16[1] << 16)) << 4*(1-iq)) >> 2;
@@ -5394,8 +5405,7 @@ void kernel_mul_mv_iq3_k_f32_impl(
 
         float4 acc = {0.f};
         for (int l = 0; l < 4; ++l) {
-            constant float * values = kvalues_iq3k_f + 8*(extra & 1);
-            extra >>= 2;
+            constant float * values = kvalues_iq3k_f + shift[l];
             aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404);
             aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404);
             for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
```