diff options
Diffstat (limited to 'ggml-cuda/vecdotq.cuh')
-rw-r--r-- | ggml-cuda/vecdotq.cuh | 35 |
1 files changed, 16 insertions, 19 deletions
diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh index b9573a7c..3b12d656 100644 --- a/ggml-cuda/vecdotq.cuh +++ b/ggml-cuda/vecdotq.cuh @@ -265,36 +265,31 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( // contiguous u/y values static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( - const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, - const half2 & dm2, const float & d8) { + const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - int sumi_d = 0; - int sumi_m = 0; + float sumf_d = 0.0f; + float sumf_m = 0.0f; #pragma unroll for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { - int sumi_d_sc = 0; - - const int sc = scales[i0 / (QI8_1/2)]; - - // fill int with 4x m - int m = sc >> 4; - m |= m << 8; - m |= m << 16; + const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]); + int sumi_d = 0; + int sumi_m = 0; + const int vi0 = v[i0/(QI8_1/2)]; #pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product - sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m + const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303; + sumi_d = __dp4a(vi, u[i], sumi_d); // SIMD dot product + sumi_m = __dp4a(0x01010101, u[i], sumi_m); } - sumi_d += sumi_d_sc * (sc & 0xF); + sumf_d += dm2f.x * sumi_d; + sumf_m += dm2f.y * sumi_m; } - const float2 dm2f = __half22float2(dm2); - - return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m); + return d8*(sumf_d - sumf_m); #else NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -352,8 +347,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { int sumi_sc = 0; +#pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product + const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404); + sumi_sc = __dp4a(vi, u[i], sumi_sc); // SIMD dot product } sumi += sumi_sc * scales[i0 / (QI8_1/2)]; |