diff options
-rw-r--r-- | ggml/src/ggml-metal.metal | 26 |
1 files changed, 6 insertions, 20 deletions
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 3a112cb7..988a820f 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -5233,8 +5233,6 @@ void kernel_mul_mv_iq2_k_f32_impl( float yl[32]; float sumf[N_DST]={0.f}, all_sum; - const int step = (sizeof(block_q2_K) * nb) / 4; - const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 const int iq = it/4; // 0 or 1 @@ -5243,18 +5241,11 @@ void kernel_mul_mv_iq2_k_f32_impl( device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; - uint32_t aux32; - thread const uint8_t * aux8 = (thread const uint8_t *)&aux32; + uint32_t aux32[2]; + thread const uint8_t * aux8 = (thread const uint8_t *)aux32; for (int ib = ix; ib < nb; ib += 4) { - //float4 sumy = {0.f, 0.f, 0.f, 0.f}; - //for (int i = 0; i < 8; ++i) { - // yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - // yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; - // yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; - // yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; - //} for (int i = 0; i < 8; ++i) { yl[i+ 0] = y4[i+ 0]; yl[i+ 8] = y4[i+32]; @@ -5276,16 +5267,11 @@ void kernel_mul_mv_iq2_k_f32_impl( for (int l = 0; l < 4; ++l) { constant float * values = kvalues_iq2k_f + 4*(extra & 1); extra >>= 2; - for (int i = 0; i < 2; ++i) { - aux32 = (q32[i] >> 2*l) & 0x03030303; - acc[l] += values[aux8[0]] * yl[8*l + 4*i + 0] + - + values[aux8[1]] * yl[8*l + 4*i + 1] + - + values[aux8[2]] * yl[8*l + 4*i + 2] + - + values[aux8[3]] * yl[8*l + 4*i + 3]; - } + aux32[0] = (q32[0] >> 2*l) & 0x03030303; + aux32[1] = (q32[1] >> 2*l) & 0x03030303; + for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]]; } - - sumf[row] += (float)xb.d * (acc[0] * (s8[0] - 15) + acc[1] * (s8[1] - 15) * acc[2] * (s8[2] - 15) + acc[3] * (s8[3] - 15)); + sumf[row] += (float)xb.d * (acc[0] * (s8[0] - 15) + acc[1] * (s8[1] - 15) + acc[2] * (s8[2] - 15) + acc[3] * (s8[3] - 15)); } |