diff options
Diffstat (limited to 'ggml-cuda')
-rw-r--r-- | ggml-cuda/convert.cu | 26 | ||||
-rw-r--r-- | ggml-cuda/vecdotq.cuh | 76 |
2 files changed, 49 insertions, 53 deletions
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index a0a826e9..4a67c498 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -425,7 +425,10 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst const int64_t ii = blockIdx.x; const block_iq1_bn * x = (const block_iq1_bn *) vx; - static const uint16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1}; + static const uint8_t k_mult[5] = {81, 27, 9, 3, 1}; + +//#define COMPUTE_VS(v) 3*v >> 8 +#define COMPUTE_VS(v) (v + (v >> 1)) >> 7 const int tid = threadIdx.x; const int il = tid/4; // 0...7 @@ -433,11 +436,24 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst dst_t * y = yy + ii*QK_K + 64*ib + 8*il; int64_t i = QK_K/QK_IQ1BN * ii + ib; if (i >= nb64) return; - uint16_t val = x[i].ql[il] | ((x[i].qh[il%4] << (8 - 4*(il/4))) & 0x0f00) | ((x[i].extra << (12 - il)) & 4096); - for (int j = 0; j < 8; ++j) { - uint16_t v = (val*k_mult[j] & 0x1fff)*3 >> 13; - y[j] = v - 1; + const int i16 = il/2; + uint8_t q = x[i].ql[3*i16+2*(il%2)]; + for (int j = 0; j < 5; ++j) { + uint8_t v = k_mult[j]*q; + int8_t vs = COMPUTE_VS(v); + y[2*(il%2)+j] = vs - 1; } + q = x[i].ql[3*i16+1]; + for (int j = 0; j < 2; ++j) { + uint8_t v = k_mult[3*(il%2)+j]*q; + int8_t vs = COMPUTE_VS(v); + y[5*(1-(il%2))+j] = vs-1; + } + uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q; + int8_t vs = COMPUTE_VS(v); + y[7] = vs - 1; + +#undef COMPUTE_VS } template<typename dst_t> diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh index d0d2e923..f0133e07 100644 --- a/ggml-cuda/vecdotq.cuh +++ b/ggml-cuda/vecdotq.cuh @@ -1078,67 +1078,47 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx; - static const int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1}; + static const uint8_t k_mult[5] = {81, 27, 9, 3, 1}; // iqs is 0 or 1 - uint8_t extra = bq1->extra >> 4*iqs; int sumi = 0; #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const int * q8 = (const int *)bq8_1[iqs].qs; - //int v[2]; - //int8_t * a = (int8_t *)v; - //for (int l = 0; l < 2; ++l) { - // int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096); - // int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096); - // for (int k = 0; k < 8; ++k) a[k] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1; - // sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], sumi)); - // for (int k = 0; k < 8; ++k) a[k] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1; - // sumi = __dp4a(v[0], q8[4*l+2], __dp4a(v[1], q8[4*l+3], sumi)); - // extra >>= 2; - //} - - int v[4]; - int8_t * a = (int8_t *)v; + int val[4]; for (int l = 0; l < 2; ++l) { - int16_t val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096); - int16_t val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096); - for (int k = 0; k < 8; ++k) { - a[k+0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1; - a[k+8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1; + int8_t * a = (int8_t *)val; + const int i16 = 2*iqs + l; + for (int k = 0; k < 3; ++k) { + uint8_t q = bq1->ql[3*i16+k]; + for (int j = 0; j < 5; ++j) { + uint8_t v = k_mult[j]*q; + int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7; + *a++ = vs-1; + } } - sumi = __dp4a(v[0], q8[4*l+0], __dp4a(v[1], q8[4*l+1], __dp4a(v[2], q8[4*l+2], __dp4a(v[3], q8[4*l+3], sumi)))); - extra >>= 2; + uint8_t v = k_mult[i16]*bq1->extra; + int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7; + *a++ = vs-1; + sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi)))); } - - //int v[8]; - //int8_t * a = (int8_t *)v; - //int16_t val1 = bq1->ql[4*iqs + 0] | ((bq1->qh[0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096); - //int16_t val2 = bq1->ql[4*iqs + 1] | ((bq1->qh[1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096); - //int16_t val3 = bq1->ql[4*iqs + 2] | ((bq1->qh[2] << (8-4*iqs)) & 0x0f00) | ((extra << 10) & 4096); - //int16_t val4 = bq1->ql[4*iqs + 3] | ((bq1->qh[3] << (8-4*iqs)) & 0x0f00) | ((extra << 9) & 4096); - //for (int k = 0; k < 8; ++k) { - // a[k+ 0] = ((val1*k_mult[k] & 0x1fff)*3 >> 13) - 1; - // a[k+ 8] = ((val2*k_mult[k] & 0x1fff)*3 >> 13) - 1; - // a[k+16] = ((val3*k_mult[k] & 0x1fff)*3 >> 13) - 1; - // a[k+24] = ((val4*k_mult[k] & 0x1fff)*3 >> 13) - 1; - //} - //sumi = __dp4a(v[0], q8[0], __dp4a(v[1], q8[1], __dp4a(v[2], q8[2], __dp4a(v[3], q8[3], sumi)))); - //sumi = __dp4a(v[4], q8[4], __dp4a(v[5], q8[5], __dp4a(v[6], q8[6], __dp4a(v[7], q8[7], sumi)))); #else const int8_t * q8 = bq8_1[iqs].qs; for (int l = 0; l < 2; ++l) { - int val1 = bq1->ql[4*iqs + 2*l+0] | ((bq1->qh[2*l+0] << (8-4*iqs)) & 0x0f00) | ((extra << 12) & 4096); - int val2 = bq1->ql[4*iqs + 2*l+1] | ((bq1->qh[2*l+1] << (8-4*iqs)) & 0x0f00) | ((extra << 11) & 4096); - for (int k = 0; k < 4; ++k) { - int v1 = (val1*k_mult[k+0] & 0x1fff)*3 >> 13; - int v2 = (val1*k_mult[k+4] & 0x1fff)*3 >> 13; - int v3 = (val2*k_mult[k+0] & 0x1fff)*3 >> 13; - int v4 = (val2*k_mult[k+4] & 0x1fff)*3 >> 13; - sumi += (v1 - 1)*q8[k+0] + (v2 - 1)*q8[k+4] + (v3 - 1)*q8[k+8] + (v4 - 1)*q8[k+12]; + const int i16 = 2*iqs + l; + for (int k = 0; k < 3; ++k) { + uint8_t q = bq1->ql[3*i16+k]; + for (int j = 0; j < 5; ++j) { + uint8_t v = k_mult[j]*q; + int8_t vs = (v + (v >> 1)) >> 7; + sumi += q8[j]*(vs - 1); + } + q8 += 5; } - q8 += 16; - extra >>= 2; + uint8_t v = k_mult[i16]*bq1->extra; + int8_t vs = (v + (v >> 1)) >> 7; + sumi += q8[0]*(vs - 1); + q8++; } #endif return __low2float(bq8_1[iqs].ds) * sumi; |