summaryrefslogtreecommitdiff
path: root/ggml/src
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src')
-rw-r--r--ggml/src/ggml-metal.m12
-rw-r--r--ggml/src/ggml-metal.metal37
2 files changed, 27 insertions, 22 deletions
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index ca08b0f5..ac183585 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -2290,11 +2290,11 @@ static enum ggml_status ggml_metal_graph_compute(
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0 ||
- src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
+ src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K) {
- const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : 16*sizeof(float);
+ else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
@@ -2697,11 +2697,11 @@ static enum ggml_status ggml_metal_graph_compute(
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 ||
src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
- src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
+ src0t == GGML_TYPE_IQ2_TN|| src0t == GGML_TYPE_IQ1_TN) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K) {
- const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : 16*sizeof(float);
+ else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index fe197309..72595c91 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -6449,8 +6449,14 @@ void kernel_mul_mv_iq3_k_f32_impl(
device const block_iq3_k * x = (device const block_iq3_k *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ threadgroup float * all_values = (threadgroup float *)shared_values + 16*sgitg;
+ {
+ if (tiisg < 16) all_values[tiisg] = kvalues_iq3k_f[tiisg];
+ simdgroup_barrier(mem_flags::mem_none);
+ }
+
float yl[32];
- float sumf[N_DST]={0.f}, all_sum;
+ float sumf[N_DST]={0.f};
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
@@ -6463,7 +6469,6 @@ void kernel_mul_mv_iq3_k_f32_impl(
uint32_t vl[2], vh[2];
uint32_t aux32[2];
thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
- uint16_t shift[4];
for (int ib = ix; ib < nb; ib += 4) {
@@ -6479,18 +6484,14 @@ void kernel_mul_mv_iq3_k_f32_impl(
device const block_iq3_k & xb = x[row*nb + ib];
device const uint16_t * ql16 = (device const uint16_t *)xb.qs + 16*iq + 4*ir;
device const uint16_t * qh16 = (device const uint16_t *)xb.qh + 4*ir;
- device const uint32_t * sc = (device const uint32_t *)xb.scales_l;
+ device const uint16_t * sc16 = (device const uint16_t *)xb.scales_l;
- const uint32_t scales32 = ((sc[iq] >> 4*is) & 0x0f0f0f0f) << 1;
+ uint32_t scales32 = sc16[2*iq+0] | (sc16[2*iq+1] << 16);
+ scales32 = ((scales32 >> 4*is) & 0x0f0f0f0f) << 1;
thread const int8_t * s8 = (thread const int8_t *)&scales32;
- uint16_t extra = xb.extra >> (8*iq + is);
+ uint16_t extra = (xb.extra >> (8*iq + is)) << 3;
uint16_t signs = xb.scales_h >> (8*iq + is);
- shift[0] = (extra << 3) & 8;
- shift[1] = (extra << 2) & 8;
- shift[2] = (extra << 1) & 8;
- shift[3] = (extra << 0) & 8;
-
vl[0] = ql16[0] | ql16[1] << 16;
vl[1] = ql16[2] | ql16[3] << 16;
vh[0] = ((qh16[0] | (qh16[1] << 16)) << 4*(1-iq)) >> 2;
@@ -6498,12 +6499,14 @@ void kernel_mul_mv_iq3_k_f32_impl(
float4 acc = {0.f};
for (int l = 0; l < 4; ++l) {
- constant float * values = kvalues_iq3k_f + shift[l];
+ threadgroup const float * values = all_values + (extra & 8);
+ //constant float * values = kvalues_iq3k_f + (extra & 8);
aux32[0] = (vl[0] & 0x03030303) | (vh[0] & 0x04040404);
aux32[1] = (vl[1] & 0x03030303) | (vh[1] & 0x04040404);
for (int j = 0; j < 8; ++j) acc[l] += yl[8*l+j] * values[aux8[j]];
vl[0] >>= 2; vl[1] >>= 2;
vh[0] >>= 1; vh[1] >>= 1;
+ extra >>= 2;
}
sumf[row] += (float)xb.d * (acc[0] * (signs & 0x01 ? -s8[0] : s8[0]) + acc[1] * (signs & 0x04 ? -s8[1] : s8[1]) +
@@ -6514,10 +6517,11 @@ void kernel_mul_mv_iq3_k_f32_impl(
y4 += 4 * QK_K;
}
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ for (int row = 0; row < N_DST; row += 2) {
+ float2 tmp{sumf[row], sumf[row+1]};
+ tmp = simd_sum(tmp);
+ if (tiisg < 2) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = tmp[tiisg];
}
}
}
@@ -6543,11 +6547,12 @@ kernel void kernel_mul_mv_iq3_k_f32(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq3_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq3_k_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_iq4_k_f32_impl(