summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Kawrakow <iwankawrakow@gmail.com> 2024-10-13 14:30:30 +0300
committer GitHub <noreply@github.com> 2024-10-13 14:30:30 +0300
commit baab1d9a1e5d28bddb91dd962223be558bf7737d (patch)
tree 4042dc7d5a565dcb08e86efc3bcecf1ef3ceb668
parent 910a13409463f7aedb0a92be013a1b9bb50f4859 (diff)
Fix and optimize iq2k Metal implementation (#86)
* I somehow broke iq2_k on Metal? - fix dequantize
* I somehow broke iq2_k on Metal? - fix dot product
* iq2_k: optimize Metal dot product 42.6 t/s -> 46.2 t/s

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- ggml/src/ggml-metal.m 19
-rw-r--r-- ggml/src/ggml-metal.metal 42
2 files changed, 35 insertions, 26 deletions
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index d5e8d6ae..ca08b0f5 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -2286,15 +2286,15 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q6_0 ||
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
- src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
+ src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0 ||
src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src0t == GGML_TYPE_IQ2_KS) {
- const int mem_size = 64*sizeof(float);
+ else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : 16*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
@@ -2693,13 +2693,18 @@ static enum ggml_status ggml_metal_graph_compute(
const int64_t _ne1 = 1;
const int tgz = dst_rows;
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q6_0 ||
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
- src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K|| src0t == GGML_TYPE_IQ2_KS ||
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 ||
+ src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
+ else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : 16*sizeof(float);
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 5ed424d3..fe197309 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -6173,7 +6173,7 @@ void kernel_mul_mv_iq2_k_f32_impl(
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
+ float sumf[N_DST]={0.f};
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
@@ -6183,9 +6183,14 @@ void kernel_mul_mv_iq2_k_f32_impl(
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
+ threadgroup float * all_values = (threadgroup float *)shared_values + 8*sgitg;
+ {
+ if (tiisg < 8) all_values[tiisg] = kvalues_iq2k_f[tiisg];
+ simdgroup_barrier(mem_flags::mem_none);
+ }
+
uint32_t aux32[2];
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
- uint16_t shift[4];
for (int ib = ix; ib < nb; ib += 4) {
@@ -6202,33 +6207,31 @@ void kernel_mul_mv_iq2_k_f32_impl(
device const uint32_t * q32 = (device const uint32_t *)xb.qs + 8*iq + 2*ir;
device const uint32_t * sc = (device const uint32_t *)xb.scales;
- const uint32_t scales32 = ((sc[iq] >> 4*is) & 0x0f0f0f0f) << 1;
+ const uint32_t scales32 = (sc[iq] >> 4*is) & 0x0f0f0f0f;
thread const int8_t * s8 = (thread const int8_t *)&scales32;
- uint16_t extra = xb.extra >> (8*iq + is);
-
- shift[0] = (extra << 2) & 4;
- shift[1] = (extra << 1) & 4;
- shift[2] = (extra >> 0) & 4;
- shift[3] = (extra >> 1) & 4;
+ uint16_t extra = (xb.extra >> (8*iq + is)) << 2;
float4 acc = {0.f};
for (int l = 0; l < 4; ++l) {
- constant float * values = kvalues_iq2k_f + shift[l];
+ threadgroup const float * values = all_values + (extra & 4);
aux32[0] = (q32[0] >> 2*l) & 0x03030303;
aux32[1] = (q32[1] >> 2*l) & 0x03030303;
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
+ extra >>= 2;
}
- sumf[row] += (float)xb.d * (acc[0] * (s8[0] - 15) + acc[1] * (s8[1] - 15) + acc[2] * (s8[2] - 15) + acc[3] * (s8[3] - 15));
+
+ sumf[row] += (float)xb.d * (acc[0] * s8[0] + acc[1] * s8[1] + acc[2] * s8[2] + acc[3] * s8[3] - 8.f*(acc[0] + acc[1] + acc[2] + acc[3]));
}
y4 += 4 * QK_K;
}
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ for (int row = 0; row < N_DST; row += 2) {
+ float2 tmp{sumf[row], sumf[row+1]};
+ tmp = simd_sum(tmp);
+ if (tiisg < 2) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = tmp[tiisg];
}
}
}
@@ -6254,11 +6257,12 @@ kernel void kernel_mul_mv_iq2_k_f32(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_iq2_ks_f32_impl(
@@ -7705,10 +7709,10 @@ template <typename type4x4>
void dequantize_iq2_k(device const block_iq2_k * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256
device const uint32_t * q32 = (device const uint32_t *)xb->qs + 8*(il/8) + 4*(il&1);
- half d = xb->d * (2*((xb->scales[il/2] >> 4*(il&1)) & 0xf) - 15);
+ half d = xb->d * (((xb->scales[il/2] >> 4*(il&1)) & 0xf) - 8);
- constant int8_t * int_values = iq2nl_values + 4*((xb->extra >> il) & 1);
- half4 values = { d * int_values[0], d * int_values[1], d * int_values[2], d * int_values[3] };
+ constant half4 * half_values = (constant half4 *)kvalues_iq2k_h;
+ half4 values = half_values[(xb->extra >> il) & 1] * d;
const int shift = 2*((il%8)/2);
uint32_t aux32;