diff options
Diffstat (limited to 'ggml-metal.metal')
-rw-r--r-- | ggml-metal.metal | 117 |
1 files changed, 87 insertions, 30 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal index cca4b0e9..2bf6894a 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -5054,6 +5054,49 @@ static inline float iq1bn_fp8_to_float(uint8_t fp8) { return s.f; } +//static constant int8_t iq1bn_values[256*5] = { +// -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 0, -1, -1, -1, 0, 0, -1, -1, -1, 1, 0, +// -1, -1, -1, -1, 1, -1, -1, -1, 0, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 0, -1, -1, 0, -1, 0, -1, -1, 1, -1, 0, -1, +// -1, -1, 0, 0, -1, -1, 0, 0, 0, -1, -1, 1, 0, 0, -1, -1, -1, 1, 0, -1, -1, 0, 1, 0, -1, -1, 1, 1, 0, -1, -1, -1, +// -1, 1, -1, -1, 0, 0, 0, 0, 0, 0, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, 0, 1, -1, -1, 0, 0, 1, -1, -1, 1, 0, 1, +// -1, -1, -1, 1, 1, -1, -1, 0, 1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 0, -1, 0, -1, -1, 0, -1, 1, -1, -1, 0, -1, +// -1, 0, -1, 0, -1, 0, 0, -1, 0, -1, 1, 0, -1, 0, -1, -1, 1, -1, 0, -1, 0, 1, -1, 0, -1, 1, 1, -1, 0, -1, -1, -1, +// 0, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, -1, 0, 0, 0, -1, 0, 0, 0, 0, -1, 1, 0, 0, 0, +// -1, -1, 1, 0, 0, -1, 0, 1, 0, 0, -1, 1, 1, 0, 0, -1, -1, -1, 1, 0, -1, 0, -1, 1, 0, -1, 1, -1, 1, 0, -1, -1, +// 0, 1, 0, -1, 0, 0, 1, 0, -1, 1, 0, 1, 0, -1, -1, 1, 1, 0, -1, 0, 1, 1, 0, -1, 1, 1, 1, 0, -1, -1, -1, -1, +// 1, -1, 0, -1, -1, 1, -1, 1, -1, -1, 1, -1, 0, 0, 0, 0, 0, -1, 0, -1, 1, -1, 0, 0, -1, 1, -1, 1, 0, -1, 1, -1, +// -1, 1, -1, 1, -1, 0, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 0, 1, -1, 0, -1, 0, 1, -1, 1, -1, 0, 1, -1, -1, 0, +// 0, 1, -1, 0, 0, 0, 1, -1, 1, 0, 0, 1, -1, -1, 1, 0, 1, -1, 0, 1, 0, 1, -1, 1, 1, 0, 1, -1, -1, -1, 1, 1, +// -1, 0, -1, 1, 1, -1, 1, -1, 1, 1, -1, 0, 0, 0, 0, 0, -1, 0, 1, 1, -1, 0, 0, 1, 1, -1, 1, 0, 1, 1, -1, -1, +// 1, 1, 1, -1, 0, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, 0, 1, -1, -1, -1, 0, -1, 0, -1, +// -1, 0, 0, 0, -1, -1, 0, 1, 0, -1, -1, 0, -1, 1, -1, -1, 0, 0, 1, -1, -1, 0, 1, 1, -1, -1, 0, -1, -1, 0, -1, 0, +// 0, -1, 0, -1, 0, 1, -1, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, 0, -1, 0, -1, 1, +// 0, -1, 0, 0, 1, 0, -1, 0, 1, 1, 0, -1, 0, -1, -1, 1, -1, 0, 0, -1, 1, -1, 0, 1, -1, 1, -1, 0, -1, 0, 1, -1, +// 0, 0, 0, 1, -1, 0, 1, 0, 1, -1, 0, -1, 1, 1, -1, 0, 0, 1, 1, -1, 0, 1, 1, 1, -1, 0, -1, -1, -1, 0, 0, 0, +// -1, -1, 0, 0, 1, -1, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, 1, -1, +// 0, 0, 0, 1, -1, 0, 0, 1, 1, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, -1, 0, 0, 0, -1, 0, 0, 0, 0, +// 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, -1, -1, 1, 0, 0, 0, -1, +// 1, 0, 0, 1, -1, 1, 0, 0, -1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, -1, 1, 1, 0, +// 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, 1, 0, 0, -1, -1, 1, 0, 1, -1, -1, 1, 0, -1, 0, -1, 1, 0, 0, +// 0, -1, 1, 0, 1, 0, -1, 1, 0, -1, 1, -1, 1, 0, 0, 1, -1, 1, 0, 1, 1, -1, 1, 0, -1, -1, 0, 1, 0, 0, -1, 0, +// 1, 0, 1, -1, 0, 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 1, 0, 1, 0, +// 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, -1, -1, 1, 1, 0, 0, -1, 1, 1, 0, 1, -1, 1, 1, 0, -1, 0, 1, 1, 0, 0, 0, +// 1, 1, 0, 1, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, -1, -1, -1, -1, 1, 0, -1, -1, -1, +// 1, 1, -1, -1, -1, 1, -1, 0, -1, -1, 1, 0, 0, -1, -1, 1, 1, 0, -1, -1, 1, -1, 1, -1, -1, 1, 0, 0, 0, 0, 0, 0, +// 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 0, -1, 1, 0, -1, 0, -1, 1, 1, -1, 0, -1, 1, -1, 0, 0, -1, 1, 0, 0, 0, +// -1, 1, 1, 0, 0, -1, 1, -1, 1, 0, -1, 1, 0, 1, 0, -1, 1, 1, 1, 0, -1, 1, -1, -1, 1, -1, 1, 0, -1, 1, -1, 1, +// 1, -1, 1, -1, 1, -1, 0, 1, -1, 1, 0, 0, 1, -1, 1, 1, 0, 1, -1, 1, -1, 1, 1, -1, 1, 0, 0, 0, 0, 0, 0, 1, +// 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 0, 1, 0, -1, -1, 0, 1, 1, -1, -1, 0, 1, -1, 0, -1, 0, 1, 0, 0, -1, 0, +// 1, 1, 0, -1, 0, 1, -1, 1, -1, 0, 1, 0, 1, -1, 0, 1, 1, 1, -1, 0, 1, -1, -1, 0, 0, 1, 0, -1, 0, 0, 1, 1, +// -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, -1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, +// 0, 0, 1, 1, 0, 0, 1, -1, -1, 1, 0, 1, 0, -1, 1, 0, 1, 1, -1, 1, 0, 1, -1, 0, 1, 0, 1, 0, 0, 1, 0, 1, +// 1, 0, 1, 0, 1, -1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, -1, -1, -1, 1, 1, 0, -1, -1, 1, 1, 1, -1, +// -1, 1, 1, -1, 0, -1, 1, 1, 0, 0, -1, 1, 1, 1, 0, -1, 1, 1, -1, 1, -1, 1, 1, 0, 1, -1, 1, 1, 1, 1, -1, 1, +// 1, 0, 0, 0, 0, 0, -1, -1, 0, 1, 1, 0, -1, 0, 1, 1, 1, -1, 0, 1, 1, -1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, +// 0, 0, 1, 1, -1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, -1, -1, 1, 1, 1, 0, -1, 1, 1, 1, 1, -1, 1, +// 1, 1, -1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, -1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, +//}; + void kernel_mul_mv_iq1_bn_f32_impl( device const void * src0, device const float * src1, @@ -5087,53 +5130,62 @@ void kernel_mul_mv_iq1_bn_f32_impl( device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float4 yl[2]; - float sumf[N_DST]={0.f}, all_sum; + float yl[16]; + float sumf[N_DST]={0.f}; const int nb32 = nb * (QK_IQ1BN / 32); - const int ix = tiisg/4; - const int ir = tiisg%4; + const int ix = tiisg/2; + const int ir = tiisg%2; - device const float4 * y4 = (device const float4 *)y + 8 * ix + 2 * ir; + device const float * y4 = (device const float *)y + 32 * ix + 16 * ir; uint32_t aux32[2]; thread const uint8_t * aux8 = (thread const uint8_t *)aux32; const float values[3] = {-1.f, 0.f, 1.f}; - constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1}; + constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1}; - for (int ib32 = ix; ib32 < nb32; ib32 += 8) { + for (int ib32 = ix; ib32 < nb32; ib32 += 16) { - yl[0] = y4[0]; yl[1] = y4[1]; + for (int j = 0; j < 16; ++j) yl[j] = y4[j]; const int ibl = ib32 / (QK_IQ1BN / 32); const int ib = ib32 % (QK_IQ1BN / 32); - const int il = 4*ib + ir; + const int i16 = 2*ib + ir; device const block_iq1_bn * xr = x + ibl; + device const uint8_t * ql = xr->ql + 3*i16; device const uint8_t * extra = (device const uint8_t *)&xr->extra; - device const uint8_t * ql = xr->ql + il; - device const uint8_t * qh = xr->qh + il%4; for (int row = 0; row < N_DST; row++) { - uint8_t h = extra[0] >> il; + float acc = 0; + int i = 0; + for (int k = 0; k < 3; ++k) { + //constant int8_t * vs = iq1bn_values + 5*ql[k]; + //for (int j = 0; j < 5; ++j) acc += yl[5*k+j]*vs[j]; + uint8_t q = ql[k]; + for (int j = 0; j < 5; ++j) { + uint8_t v = k_mult[j]*q; + v = 3*v >> 8; //(v + (v >> 1)) >> 7; + acc += yl[i++] * values[v]; + } + } + //constant int8_t * vs = iq1bn_values + 5*extra[0]; + //acc += yl[15] * vs[i16]; + uint8_t v = k_mult[i16]*extra[0]; + v = 3*v >> 8; //(v + (v >> 1)) >> 7; + acc += yl[15] * values[v]; - int16_t val = ql[0] | ((qh[0] << (8 - 4*(il/4))) & 0x0f00) | ((extra[0] << (12 - il)) & 4096); - float4 acc4 = yl[0] * float4{values[(val*k_mult[0] & 0x1fff)*3 >> 13], values[(val*k_mult[1] & 0x1fff)*3 >> 13], - values[(val*k_mult[2] & 0x1fff)*3 >> 13], values[(val*k_mult[3] & 0x1fff)*3 >> 13]} - + yl[1] * float4{values[(val*k_mult[4] & 0x1fff)*3 >> 13], values[(val*k_mult[5] & 0x1fff)*3 >> 13], - values[(val*k_mult[6] & 0x1fff)*3 >> 13], values[(val*k_mult[7] & 0x1fff)*3 >> 13]}; - sumf[row] += acc4[0] + acc4[1] + acc4[2] + acc4[3]; + sumf[row] += acc; extra += nb*sizeof(block_iq1_bn); ql += nb*sizeof(block_iq1_bn); - qh += nb*sizeof(block_iq1_bn); } - y4 += 32 * 2; + y4 += 32 * 16; } for (int row = 0; row < N_DST; row += 2) { @@ -5990,18 +6042,23 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & template <typename type4x4> void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) { // il is in 0...3 - uint8_t gs = xb->extra >> 2*il; - - constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1}; - - short il1 = 2*il+0, il2 = 2*il+1; - int16_t v1 = xb->ql[il1] | ((xb->qh[il1%4] << (8 - 4*(il1/4))) & 0x0f00) | ((gs << 12) & 4096); - int16_t v2 = xb->ql[il2] | ((xb->qh[il2%4] << (8 - 4*(il2/4))) & 0x0f00) | ((gs << 11) & 4096); - for (int i = 0; i < 8; ++i) { - reg[i/4+0][i%4] = ((v1*k_mult[i] & 0x1fff)*3 >> 13) - 1; - reg[i/4+2][i%4] = ((v2*k_mult[i] & 0x1fff)*3 >> 13) - 1; + constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1}; + + int i = 0; + for (int k = 0; k < 3; ++k) { + uint8_t q = xb->ql[3*il + k]; + for (int j = 0; j < 5; ++j) { + uint8_t v = k_mult[j]*q; + int8_t vs = 3*v >> 8; + //int8_t vs = (v + (v >> 1)) >> 7; + reg[i/4][i%4] = vs - 1; + ++i; + } } + uint8_t v = k_mult[il]*xb->extra; + int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7; + reg[3][3] = vs - 1; } template <typename type4x4> |