diff options
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 1c8a991d..9a34270b 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -12735,7 +12735,7 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI int nb = n / 32; GGML_ASSERT(nb%4 == 0); uint8x16_t qx[8]; - int32x4_t acc[nrc_y] = {}; + float32x4_t acc[nrc_y] = {}; auto ms = vdup_n_u16(0x8000); auto mask = vdupq_n_s8(0x03); float d8[4*nrc_y]; @@ -14140,7 +14140,7 @@ void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& i auto scale2 = vmulq_f32(scale2_x, scale_y); info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0]))); info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1]))); - acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f); + acc[2*iy+0] = acc[2*iy+1] = vdupq_n_s32(0.f); } } } @@ -14823,11 +14823,11 @@ inline float32x4_t v_tanh(float32x4_t x) { return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask))); //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); } -inline float32x4_t v_tanh(float16x8_t x) { - auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x))); - auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x))); - return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2)); -} +//inline float32x4_t v_tanh(float16x8_t x) { +// auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x))); +// auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x))); +// return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2)); +//} inline float32x4_t v_silu(float32x4_t x) { const float32x4_t one = vdupq_n_f32(1.0f); const float32x4_t zero = vdupq_n_f32(0.0f); @@ -15671,7 +15671,9 @@ struct HelperQ60 final : public BaseHelper<step> { auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0; #ifdef __aarch64__ // TODO - auto vd = F16::set1(*(const float16_t *)&dl->d); + const float16_t * d16 = (const float16_t *)&dl->d; + auto vd = F16::set1(d16[0]); + //auto vd = F16::set1(*(const float16_t *)&dl->d); auto qh8 = vld1_u8(dl->qh); auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8); auto qs = vld1q_u8(dl->qs); @@ -15819,7 +15821,7 @@ struct FlashMS { return vmaxvq_f32(vmax); } inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) { - auto vzero = vdupq_n_f32(0); + auto vzero = vdupq_n_f16(0); auto vinf = vdupq_n_f32(-INFINITY); for (int l = 0; l < k_step/8; ++l) { auto vm = vceqq_f16(vzero, vld1q_f16((const float16_t *)mask + 8*l)); @@ -15827,9 +15829,9 @@ struct FlashMS { auto vm2 = vzip2q_u16(vm, vm); auto kq = vld1q_f32_x2(cache + k_step*j + 8*l); vk[2*l+0] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[0]), vm1), - vbicq_u32(vinf, vm1))); + vbicq_u32(vreinterpretq_u32_f32(vinf), vm1))); vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2), - vbicq_u32(vinf, vm2))); + vbicq_u32(vreinterpretq_u32_f32(vinf), vm2))); } float32x4_t vmax = vdupq_n_f32(-INFINITY); auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale)); |