summaryrefslogtreecommitdiff
path: root/ggml-metal.metal
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.metal')
-rw-r--r--ggml-metal.metal117
1 files changed, 87 insertions, 30 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index cca4b0e9..2bf6894a 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -5054,6 +5054,49 @@ static inline float iq1bn_fp8_to_float(uint8_t fp8) {
return s.f;
}
+//static constant int8_t iq1bn_values[256*5] = {
+// -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 0, -1, -1, -1, 0, 0, -1, -1, -1, 1, 0,
+// -1, -1, -1, -1, 1, -1, -1, -1, 0, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 0, -1, -1, 0, -1, 0, -1, -1, 1, -1, 0, -1,
+// -1, -1, 0, 0, -1, -1, 0, 0, 0, -1, -1, 1, 0, 0, -1, -1, -1, 1, 0, -1, -1, 0, 1, 0, -1, -1, 1, 1, 0, -1, -1, -1,
+// -1, 1, -1, -1, 0, 0, 0, 0, 0, 0, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, 0, 1, -1, -1, 0, 0, 1, -1, -1, 1, 0, 1,
+// -1, -1, -1, 1, 1, -1, -1, 0, 1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 0, -1, 0, -1, -1, 0, -1, 1, -1, -1, 0, -1,
+// -1, 0, -1, 0, -1, 0, 0, -1, 0, -1, 1, 0, -1, 0, -1, -1, 1, -1, 0, -1, 0, 1, -1, 0, -1, 1, 1, -1, 0, -1, -1, -1,
+// 0, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, -1, 0, 0, 0, -1, 0, 0, 0, 0, -1, 1, 0, 0, 0,
+// -1, -1, 1, 0, 0, -1, 0, 1, 0, 0, -1, 1, 1, 0, 0, -1, -1, -1, 1, 0, -1, 0, -1, 1, 0, -1, 1, -1, 1, 0, -1, -1,
+// 0, 1, 0, -1, 0, 0, 1, 0, -1, 1, 0, 1, 0, -1, -1, 1, 1, 0, -1, 0, 1, 1, 0, -1, 1, 1, 1, 0, -1, -1, -1, -1,
+// 1, -1, 0, -1, -1, 1, -1, 1, -1, -1, 1, -1, 0, 0, 0, 0, 0, -1, 0, -1, 1, -1, 0, 0, -1, 1, -1, 1, 0, -1, 1, -1,
+// -1, 1, -1, 1, -1, 0, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 0, 1, -1, 0, -1, 0, 1, -1, 1, -1, 0, 1, -1, -1, 0,
+// 0, 1, -1, 0, 0, 0, 1, -1, 1, 0, 0, 1, -1, -1, 1, 0, 1, -1, 0, 1, 0, 1, -1, 1, 1, 0, 1, -1, -1, -1, 1, 1,
+// -1, 0, -1, 1, 1, -1, 1, -1, 1, 1, -1, 0, 0, 0, 0, 0, -1, 0, 1, 1, -1, 0, 0, 1, 1, -1, 1, 0, 1, 1, -1, -1,
+// 1, 1, 1, -1, 0, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, 0, 1, -1, -1, -1, 0, -1, 0, -1,
+// -1, 0, 0, 0, -1, -1, 0, 1, 0, -1, -1, 0, -1, 1, -1, -1, 0, 0, 1, -1, -1, 0, 1, 1, -1, -1, 0, -1, -1, 0, -1, 0,
+// 0, -1, 0, -1, 0, 1, -1, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, 0, -1, 0, -1, 1,
+// 0, -1, 0, 0, 1, 0, -1, 0, 1, 1, 0, -1, 0, -1, -1, 1, -1, 0, 0, -1, 1, -1, 0, 1, -1, 1, -1, 0, -1, 0, 1, -1,
+// 0, 0, 0, 1, -1, 0, 1, 0, 1, -1, 0, -1, 1, 1, -1, 0, 0, 1, 1, -1, 0, 1, 1, 1, -1, 0, -1, -1, -1, 0, 0, 0,
+// -1, -1, 0, 0, 1, -1, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, 1, -1,
+// 0, 0, 0, 1, -1, 0, 0, 1, 1, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, -1, 0, 0, 0, -1, 0, 0, 0, 0,
+// 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, -1, -1, 1, 0, 0, 0, -1,
+// 1, 0, 0, 1, -1, 1, 0, 0, -1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, -1, 1, 1, 0,
+// 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, 1, 0, 0, -1, -1, 1, 0, 1, -1, -1, 1, 0, -1, 0, -1, 1, 0, 0,
+// 0, -1, 1, 0, 1, 0, -1, 1, 0, -1, 1, -1, 1, 0, 0, 1, -1, 1, 0, 1, 1, -1, 1, 0, -1, -1, 0, 1, 0, 0, -1, 0,
+// 1, 0, 1, -1, 0, 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 1, 0, 1, 0,
+// 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, -1, -1, 1, 1, 0, 0, -1, 1, 1, 0, 1, -1, 1, 1, 0, -1, 0, 1, 1, 0, 0, 0,
+// 1, 1, 0, 1, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, -1, -1, -1, -1, 1, 0, -1, -1, -1,
+// 1, 1, -1, -1, -1, 1, -1, 0, -1, -1, 1, 0, 0, -1, -1, 1, 1, 0, -1, -1, 1, -1, 1, -1, -1, 1, 0, 0, 0, 0, 0, 0,
+// 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 0, -1, 1, 0, -1, 0, -1, 1, 1, -1, 0, -1, 1, -1, 0, 0, -1, 1, 0, 0, 0,
+// -1, 1, 1, 0, 0, -1, 1, -1, 1, 0, -1, 1, 0, 1, 0, -1, 1, 1, 1, 0, -1, 1, -1, -1, 1, -1, 1, 0, -1, 1, -1, 1,
+// 1, -1, 1, -1, 1, -1, 0, 1, -1, 1, 0, 0, 1, -1, 1, 1, 0, 1, -1, 1, -1, 1, 1, -1, 1, 0, 0, 0, 0, 0, 0, 1,
+// 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 0, 1, 0, -1, -1, 0, 1, 1, -1, -1, 0, 1, -1, 0, -1, 0, 1, 0, 0, -1, 0,
+// 1, 1, 0, -1, 0, 1, -1, 1, -1, 0, 1, 0, 1, -1, 0, 1, 1, 1, -1, 0, 1, -1, -1, 0, 0, 1, 0, -1, 0, 0, 1, 1,
+// -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, -1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0,
+// 0, 0, 1, 1, 0, 0, 1, -1, -1, 1, 0, 1, 0, -1, 1, 0, 1, 1, -1, 1, 0, 1, -1, 0, 1, 0, 1, 0, 0, 1, 0, 1,
+// 1, 0, 1, 0, 1, -1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, -1, -1, -1, 1, 1, 0, -1, -1, 1, 1, 1, -1,
+// -1, 1, 1, -1, 0, -1, 1, 1, 0, 0, -1, 1, 1, 1, 0, -1, 1, 1, -1, 1, -1, 1, 1, 0, 1, -1, 1, 1, 1, 1, -1, 1,
+// 1, 0, 0, 0, 0, 0, -1, -1, 0, 1, 1, 0, -1, 0, 1, 1, 1, -1, 0, 1, 1, -1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1,
+// 0, 0, 1, 1, -1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, -1, -1, 1, 1, 1, 0, -1, 1, 1, 1, 1, -1, 1,
+// 1, 1, -1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, -1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+//};
+
void kernel_mul_mv_iq1_bn_f32_impl(
device const void * src0,
device const float * src1,
@@ -5087,53 +5130,62 @@ void kernel_mul_mv_iq1_bn_f32_impl(
device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
- float4 yl[2];
- float sumf[N_DST]={0.f}, all_sum;
+ float yl[16];
+ float sumf[N_DST]={0.f};
const int nb32 = nb * (QK_IQ1BN / 32);
- const int ix = tiisg/4;
- const int ir = tiisg%4;
+ const int ix = tiisg/2;
+ const int ir = tiisg%2;
- device const float4 * y4 = (device const float4 *)y + 8 * ix + 2 * ir;
+ device const float * y4 = (device const float *)y + 32 * ix + 16 * ir;
uint32_t aux32[2];
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
const float values[3] = {-1.f, 0.f, 1.f};
- constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
+ constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
- for (int ib32 = ix; ib32 < nb32; ib32 += 8) {
+ for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
- yl[0] = y4[0]; yl[1] = y4[1];
+ for (int j = 0; j < 16; ++j) yl[j] = y4[j];
const int ibl = ib32 / (QK_IQ1BN / 32);
const int ib = ib32 % (QK_IQ1BN / 32);
- const int il = 4*ib + ir;
+ const int i16 = 2*ib + ir;
device const block_iq1_bn * xr = x + ibl;
+ device const uint8_t * ql = xr->ql + 3*i16;
device const uint8_t * extra = (device const uint8_t *)&xr->extra;
- device const uint8_t * ql = xr->ql + il;
- device const uint8_t * qh = xr->qh + il%4;
for (int row = 0; row < N_DST; row++) {
- uint8_t h = extra[0] >> il;
+ float acc = 0;
+ int i = 0;
+ for (int k = 0; k < 3; ++k) {
+ //constant int8_t * vs = iq1bn_values + 5*ql[k];
+ //for (int j = 0; j < 5; ++j) acc += yl[5*k+j]*vs[j];
+ uint8_t q = ql[k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ acc += yl[i++] * values[v];
+ }
+ }
+ //constant int8_t * vs = iq1bn_values + 5*extra[0];
+ //acc += yl[15] * vs[i16];
+ uint8_t v = k_mult[i16]*extra[0];
+ v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ acc += yl[15] * values[v];
- int16_t val = ql[0] | ((qh[0] << (8 - 4*(il/4))) & 0x0f00) | ((extra[0] << (12 - il)) & 4096);
- float4 acc4 = yl[0] * float4{values[(val*k_mult[0] & 0x1fff)*3 >> 13], values[(val*k_mult[1] & 0x1fff)*3 >> 13],
- values[(val*k_mult[2] & 0x1fff)*3 >> 13], values[(val*k_mult[3] & 0x1fff)*3 >> 13]}
- + yl[1] * float4{values[(val*k_mult[4] & 0x1fff)*3 >> 13], values[(val*k_mult[5] & 0x1fff)*3 >> 13],
- values[(val*k_mult[6] & 0x1fff)*3 >> 13], values[(val*k_mult[7] & 0x1fff)*3 >> 13]};
- sumf[row] += acc4[0] + acc4[1] + acc4[2] + acc4[3];
+ sumf[row] += acc;
extra += nb*sizeof(block_iq1_bn);
ql += nb*sizeof(block_iq1_bn);
- qh += nb*sizeof(block_iq1_bn);
}
- y4 += 32 * 2;
+ y4 += 32 * 16;
}
for (int row = 0; row < N_DST; row += 2) {
@@ -5990,18 +6042,23 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
template <typename type4x4>
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
// il is in 0...3
- uint8_t gs = xb->extra >> 2*il;
-
- constexpr int16_t k_mult[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
-
- short il1 = 2*il+0, il2 = 2*il+1;
- int16_t v1 = xb->ql[il1] | ((xb->qh[il1%4] << (8 - 4*(il1/4))) & 0x0f00) | ((gs << 12) & 4096);
- int16_t v2 = xb->ql[il2] | ((xb->qh[il2%4] << (8 - 4*(il2/4))) & 0x0f00) | ((gs << 11) & 4096);
- for (int i = 0; i < 8; ++i) {
- reg[i/4+0][i%4] = ((v1*k_mult[i] & 0x1fff)*3 >> 13) - 1;
- reg[i/4+2][i%4] = ((v2*k_mult[i] & 0x1fff)*3 >> 13) - 1;
+ constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+ int i = 0;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = xb->ql[3*il + k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = 3*v >> 8;
+ //int8_t vs = (v + (v >> 1)) >> 7;
+ reg[i/4][i%4] = vs - 1;
+ ++i;
+ }
}
+ uint8_t v = k_mult[il]*xb->extra;
+ int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ reg[3][3] = vs - 1;
}
template <typename type4x4>