diff options
Diffstat (limited to 'ggml-metal.metal')
-rw-r--r-- | ggml-metal.metal | 40 |
1 files changed, 10 insertions, 30 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal index e5ef552c..43d339c0 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4992,6 +4992,12 @@ void kernel_mul_mv_iq1_m_f32_impl( } } +static inline float iq1bn_fp8_to_float(uint8_t fp8) { + typedef union { float f; uint32_t i; } scale_t; + scale_t s; s.i = (((fp8 >> 5) + 116) << 23) | ((fp8 & 0x1f) << 18); + return s.f; +} + void kernel_mul_mv_iq1_bn_f32_impl( device const void * src0, device const float * src1, @@ -5036,13 +5042,8 @@ void kernel_mul_mv_iq1_bn_f32_impl( device const float * y4 = y + 32 * ix + 8 * ir; - typedef union { float f; uint32_t i; } scale_t; - scale_t scale; - for (int row = 0; row < N_DST; ++row) { - uint8_t u = x[nb*row].extra & 0xff; - scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); - d1bn[row] = scale.f; + d1bn[row] = iq1bn_fp8_to_float(x[nb*row].extra & 0xff); } uint32_t aux32[2]; @@ -5138,9 +5139,6 @@ void kernel_mul_mv_iq2_bn_f32_impl( device const float * y4 = y + 64 * ix + 4 * ir; - typedef union { float f; uint32_t i; } scale_t; - scale_t scale; - for (int row = 0; row < N_DST; ++row) { d1bn[row] = x[nb*row].d; } @@ -5945,15 +5943,10 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & template <typename type4x4> void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) { // il is in 0...3 - typedef union { float f; uint32_t i; } scale_t; - scale_t scale; - uint8_t u = xb->extra & 0xff; - scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); - //uint32_t u = xb->extra & 0xff; - //scale.i = (u << 19) + 905969664; + const float d = iq1bn_fp8_to_float(xb->extra & 0xff); uint8_t gs = xb->extra >> (8 + 2*il); - const float d1 = gs & 1 ? -scale.f : scale.f; - const float d2 = gs & 2 ? -scale.f : scale.f; + const float d1 = gs & 1 ? -d : d; + const float d2 = gs & 2 ? -d : d; uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)]; uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)]; @@ -5969,19 +5962,6 @@ void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 reg[2][i] = d2*aux8[2] - d2; reg[3][i] = d2*aux8[3] - d2; } - - //Basically same performance as above. I guess, the compiler makes the transformation automatically - //uint16_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)]; - //uint16_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)]; - //for (int i = 0; i < 4; ++i) { - // reg[0][i] = d1*((v1 >> 2*i) & 3) - d1; - // reg[2][i] = d2*((v2 >> 2*i) & 3) - d2; - //} - //v1 >>= 8; v2 >>= 8; - //for (int i = 0; i < 4; ++i) { - // reg[1][i] = d1*((v1 >> 2*i) & 3) - d1; - // reg[3][i] = d2*((v2 >> 2*i) & 3) - d2; - //} } template <typename type4x4> |