summaryrefslogtreecommitdiff
path: root/ggml-metal.metal
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.metal')
-rw-r--r--ggml-metal.metal40
1 files changed, 10 insertions, 30 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index e5ef552c..43d339c0 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -4992,6 +4992,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
}
}
+static inline float iq1bn_fp8_to_float(uint8_t fp8) { // Expands the 8-bit packed scale `fp8` into a float by assembling the 32-bit pattern directly (no arithmetic conversion).
+ typedef union { float f; uint32_t i; } scale_t; // overlay so the pattern written through `i` can be read back through `f`
+ scale_t s; s.i = (((fp8 >> 5) + 116) << 23) | ((fp8 & 0x1f) << 18); // top 3 bits of fp8, offset by 116, land in bits 23..30; low 5 bits land in bits 18..22; bit 31 and bits 0..17 stay zero
+ return s.f; // assuming IEEE-754 binary32 layout: biased exponent 116..123 with a 5-bit fraction, i.e. a positive value from 2^-11 up to just under 2^-3
+}
+
void kernel_mul_mv_iq1_bn_f32_impl(
device const void * src0,
device const float * src1,
@@ -5036,13 +5042,8 @@ void kernel_mul_mv_iq1_bn_f32_impl(
device const float * y4 = y + 32 * ix + 8 * ir;
- typedef union { float f; uint32_t i; } scale_t;
- scale_t scale;
-
for (int row = 0; row < N_DST; ++row) {
- uint8_t u = x[nb*row].extra & 0xff;
- scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
- d1bn[row] = scale.f;
+ d1bn[row] = iq1bn_fp8_to_float(x[nb*row].extra & 0xff);
}
uint32_t aux32[2];
@@ -5138,9 +5139,6 @@ void kernel_mul_mv_iq2_bn_f32_impl(
device const float * y4 = y + 64 * ix + 4 * ir;
- typedef union { float f; uint32_t i; } scale_t;
- scale_t scale;
-
for (int row = 0; row < N_DST; ++row) {
d1bn[row] = x[nb*row].d;
}
@@ -5945,15 +5943,10 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
template <typename type4x4>
void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
// il is in 0...3
- typedef union { float f; uint32_t i; } scale_t;
- scale_t scale;
- uint8_t u = xb->extra & 0xff;
- scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19);
- //uint32_t u = xb->extra & 0xff;
- //scale.i = (u << 19) + 905969664;
+ const float d = iq1bn_fp8_to_float(xb->extra & 0xff);
uint8_t gs = xb->extra >> (8 + 2*il);
- const float d1 = gs & 1 ? -scale.f : scale.f;
- const float d2 = gs & 2 ? -scale.f : scale.f;
+ const float d1 = gs & 1 ? -d : d;
+ const float d2 = gs & 2 ? -d : d;
uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)];
uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)];
@@ -5969,19 +5962,6 @@ void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4
reg[2][i] = d2*aux8[2] - d2;
reg[3][i] = d2*aux8[3] - d2;
}
-
- //Basically same performance as above. I guess, the compiler makes the transformation automatically
- //uint16_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)];
- //uint16_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)];
- //for (int i = 0; i < 4; ++i) {
- // reg[0][i] = d1*((v1 >> 2*i) & 3) - d1;
- // reg[2][i] = d2*((v2 >> 2*i) & 3) - d2;
- //}
- //v1 >>= 8; v2 >>= 8;
- //for (int i = 0; i < 4; ++i) {
- // reg[1][i] = d1*((v1 >> 2*i) & 3) - d1;
- // reg[3][i] = d2*((v2 >> 2*i) & 3) - d2;
- //}
}
template <typename type4x4>