summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp24
1 files changed, 13 insertions, 11 deletions
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 1c8a991d..9a34270b 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -12735,7 +12735,7 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
int nb = n / 32;
GGML_ASSERT(nb%4 == 0);
uint8x16_t qx[8];
- int32x4_t acc[nrc_y] = {};
+ float32x4_t acc[nrc_y] = {};
auto ms = vdup_n_u16(0x8000);
auto mask = vdupq_n_s8(0x03);
float d8[4*nrc_y];
@@ -14140,7 +14140,7 @@ void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& i
auto scale2 = vmulq_f32(scale2_x, scale_y);
info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
- acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
+ acc[2*iy+0] = acc[2*iy+1] = vdupq_n_s32(0.f);
}
}
}
@@ -14823,11 +14823,11 @@ inline float32x4_t v_tanh(float32x4_t x) {
return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
//return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
}
-inline float32x4_t v_tanh(float16x8_t x) {
- auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
- auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
- return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
-}
+//inline float32x4_t v_tanh(float16x8_t x) {
+// auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
+// auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
+// return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+//}
inline float32x4_t v_silu(float32x4_t x) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t zero = vdupq_n_f32(0.0f);
@@ -15671,7 +15671,9 @@ struct HelperQ60 final : public BaseHelper<step> {
auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0;
#ifdef __aarch64__
// TODO
- auto vd = F16::set1(*(const float16_t *)&dl->d);
+ const float16_t * d16 = (const float16_t *)&dl->d;
+ auto vd = F16::set1(d16[0]);
+ //auto vd = F16::set1(*(const float16_t *)&dl->d);
auto qh8 = vld1_u8(dl->qh);
auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
auto qs = vld1q_u8(dl->qs);
@@ -15819,7 +15821,7 @@ struct FlashMS {
return vmaxvq_f32(vmax);
}
inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) {
- auto vzero = vdupq_n_f32(0);
+ auto vzero = vdupq_n_f16(0);
auto vinf = vdupq_n_f32(-INFINITY);
for (int l = 0; l < k_step/8; ++l) {
auto vm = vceqq_f16(vzero, vld1q_f16((const float16_t *)mask + 8*l));
@@ -15827,9 +15829,9 @@ struct FlashMS {
auto vm2 = vzip2q_u16(vm, vm);
auto kq = vld1q_f32_x2(cache + k_step*j + 8*l);
vk[2*l+0] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[0]), vm1),
- vbicq_u32(vinf, vm1)));
+ vbicq_u32(vreinterpretq_u32_f32(vinf), vm1)));
vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2),
- vbicq_u32(vinf, vm2)));
+ vbicq_u32(vreinterpretq_u32_f32(vinf), vm2)));
}
float32x4_t vmax = vdupq_n_f32(-INFINITY);
auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale));