summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- ggml/src/ggml-metal.metal 66
1 files changed, 50 insertions, 16 deletions
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 287d8563..e8f742fc 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -7473,11 +7473,11 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
template <typename type4x4>
-void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
+void dequantize_iq1_bn(half d, device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
// il is in 0...3
constexpr uint16_t k_mult[5] = {81, 27, 9, 3, 1};
- constexpr half k_values[3] = {-1.h, 0.h, 1.h};
+ const half k_values[3] = {-d, 0.h, d};
for (int k = 0; k < 3; ++k) {
uint16_t q = xb->ql[3*il + k];
@@ -7496,15 +7496,22 @@ void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4
}
template <typename type4x4>
-void dequantize_iq2_bn(device const block_iq2_bn * xb, short il, thread type4x4 & reg) {
+void dequantize_iq2_bn(half d, device const block_iq2_bn * xb, short il, thread type4x4 & reg) {
// il is in 0...3
constexpr half k_scale[4] = {1.h, 0.25h, 0.0625h, 0.015625h};
- constexpr uint8_t k_mask[4] = {0x03, 0x0c, 0x30, 0xc0};
- const half d = k_scale[il];
- uint8_t mask = k_mask[il];
+ const half db = d * k_scale[il];
+ const uint32_t mask = 0x03030303 << 2*il;
- for (int j = 0; j < 16; ++j) {
- reg[j/4][j%4] = d * (xb->qs[j] & mask) - 1;
+ device const uint32_t * qs = (device const uint32_t *)xb->qs;
+ uint32_t aux32;
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32;
+
+ for (int j = 0; j < 4; ++j) {
+ aux32 = qs[j] & mask;
+ reg[j][0] = db * aux8[0] - d;
+ reg[j][1] = db * aux8[1] - d;
+ reg[j][2] = db * aux8[2] - d;
+ reg[j][3] = db * aux8[3] - d;
}
}
@@ -7855,12 +7862,12 @@ struct DefaultDequantizer {
};
template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(device const Block *, short, thread T4x4&), bool may_not_be_aligned = false>
-struct DequantizerRS{
+struct DequantizerRS {
using type4x4 = T4x4;
DequantizerRS(device const char * cx, short il = 0) : il(il) {
if (may_not_be_aligned) {
thread char * aux = (thread char *)&d;
- for (int i = 0; i < sizeof(d); ++i) aux[i] = cx[i];
+ for (int i = 0; i < int(sizeof(d)); ++i) aux[i] = cx[i];
} else {
d = *(device const Scale *)cx;
}
@@ -7883,6 +7890,33 @@ struct DequantizerRS{
Scale d;
};
+template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(half d, device const Block *, short, thread T4x4&), bool may_not_be_aligned = false>
+struct DequantizerRSBN {
+ using type4x4 = T4x4;
+ DequantizerRSBN(device const char * cx, short il = 0) : il(il) {
+ if (may_not_be_aligned) {
+ thread char * aux = (thread char *)&d;
+ for (int i = 0; i < int(sizeof(d)); ++i) aux[i] = cx[i];
+ } else {
+ d = *(device const Scale *)cx;
+ }
+ x = (device const Block *)(cx + sizeof(Scale));
+ }
+ inline void convert(thread T4x4& t) const {
+ dequantize(d, x, il, t);
+ }
+ inline void convert(int64_t ind, thread T4x4& t) {
+ dequantize(d, x + ind/nl, ind%nl, t);
+ }
+ inline void next() {
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
+ }
+ device const Block * x;
+ short il;
+ Scale d;
+};
+
// each block_q contains 16*nl weights
template<typename T, typename simdgroup_T8x8, typename Dequantizer>
kernel void kernel_mul_mm(device const uchar * src0,
@@ -8251,8 +8285,8 @@ template [[host_name("kernel_get_rows_iq3_k")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_iq4_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_k, QK_NL, dequantize_iq4_k>;
template [[host_name("kernel_get_rows_iq5_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq5_k, QK_NL, dequantize_iq5_k>;
template [[host_name("kernel_get_rows_iq6_k")]] kernel get_rows_q_t kernel_get_rows_q<block_iq6_k, QK_NL, dequantize_iq6_k>;
-template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>;
-template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>;
+template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRSBN<float4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>;
+template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRSBN<float4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>;
template [[host_name("kernel_get_rows_iq4_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>;
template [[host_name("kernel_get_rows_iq4_kss")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
@@ -8294,8 +8328,8 @@ template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_k, QK_NL, dequantize_iq4_k>>;
template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq5_k, QK_NL, dequantize_iq5_k>>;
template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq6_k, QK_NL, dequantize_iq6_k>>;
-template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>;
-template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>;
+template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>;
+template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>;
template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>;
template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
@@ -8334,8 +8368,8 @@ template [[host_name("kernel_mul_mm_id_iq3_k_f32")]] kernel mat_mm_id_t kernel
template [[host_name("kernel_mul_mm_id_iq4_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq4_k, QK_NL, dequantize_iq4_k>>;
template [[host_name("kernel_mul_mm_id_iq5_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq5_k, QK_NL, dequantize_iq5_k>>;
template [[host_name("kernel_mul_mm_id_iq6_k_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_iq6_k, QK_NL, dequantize_iq6_k>>;
-template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>;
-template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>;
+template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRSBN<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>;
+template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRSBN<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>;
template [[host_name("kernel_mul_mm_id_iq4_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>;
template [[host_name("kernel_mul_mm_id_iq4_kss_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
template [[host_name("kernel_mul_mm_id_iq2_ks_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;