diff options
author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-18 20:08:28 +0300 |
---|---|---|
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:52 +0300 |
commit | 927e251a12fa287e13c6bd9667ee97d783486c09 (patch) | |
tree | 90ed8827fc28630f52e92d8b8ea664198a6f5829 | |
parent | 181fd9c56eaa64d0a92f9e8be7387f409cfa8745 (diff) |
Bitnet(1.75 bpw): higher precision fp8 scale
Use 3 bits for the exponent and 5 bits for the mantissa.
This makes the PPL the same as with fp16 (but the previous
version, with 4 bits each for the exponent and the mantissa, was
good enough for any practical purposes).
-rw-r--r-- | ggml-cuda/common.cuh | 6 | ||||
-rw-r--r-- | ggml-cuda/convert.cu | 7 | ||||
-rw-r--r-- | ggml-cuda/vecdotq.cuh | 7 | ||||
-rw-r--r-- | ggml-metal.metal | 40 | ||||
-rw-r--r-- | iqk-quantize.cpp | 42 | ||||
-rw-r--r-- | iqk-quantize.h | 46 | ||||
-rw-r--r-- | iqk_mul_mat.cpp | 17 |
7 files changed, 81 insertions, 84 deletions
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 892fd5a6..1c2d7215 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -501,6 +501,12 @@ static __device__ __forceinline__ float get_alibi_slope( return powf(base, exph); } +static __device__ __forceinline__ float iq1bn_fp8_to_float(uint8_t fp8) { + typedef union { float f; uint32_t i; } scale_t; + scale_t s; s.i = (((fp8 >> 5) + 116) << 23) | ((fp8 & 0x1f) << 18); + return s.f; +} + template <ggml_type type> struct ggml_cuda_type_traits; diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index 13f9f246..2a897738 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -432,11 +432,8 @@ static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst int64_t i = QK_K/QK_IQ1BN * ii + ib/(QK_IQ1BN/32); if (i >= nb64) return; ib = ib%(QK_IQ1BN/32); - typedef union { float f; uint32_t i; } scale_t; - scale_t s; - uint8_t u = x[i].extra & 0xff; - s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); - const float dl = x[i].extra & (1 << (4*ib + il + 8)) ? -s.f : s.f; + float d = iq1bn_fp8_to_float(x[i].extra & 0xff); + const float dl = x[i].extra & (1 << (4*ib + il + 8)) ? 
-d : d; const float ml = -dl; uint16_t idx = x[i].ql[4*ib + il] | ((x[i].qh[2*ib + il/2] << (8 - 4*(il%2))) & 0x0f00); const uint16_t gp = iq1bn_grid_u16[idx]; diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh index 6ec2035a..764a19d7 100644 --- a/ggml-cuda/vecdotq.cuh +++ b/ggml-cuda/vecdotq.cuh @@ -1078,10 +1078,7 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx; - typedef union { float f; uint32_t i; } scale_t; - scale_t s; - uint8_t u = bq1->extra & 0xff; - s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + float d = iq1bn_fp8_to_float(bq1->extra & 0xff); uint8_t extra = bq1->extra >> (8 + 4*iqs); int sumi = 0; #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -1110,7 +1107,7 @@ static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1( q8 += 8; } #endif - return s.f * __low2float(bq8_1[iqs].ds) * sumi; + return d * __low2float(bq8_1[iqs].ds) * sumi; } // TODO diff --git a/ggml-metal.metal b/ggml-metal.metal index e5ef552c..43d339c0 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4992,6 +4992,12 @@ void kernel_mul_mv_iq1_m_f32_impl( } } +static inline float iq1bn_fp8_to_float(uint8_t fp8) { + typedef union { float f; uint32_t i; } scale_t; + scale_t s; s.i = (((fp8 >> 5) + 116) << 23) | ((fp8 & 0x1f) << 18); + return s.f; +} + void kernel_mul_mv_iq1_bn_f32_impl( device const void * src0, device const float * src1, @@ -5036,13 +5042,8 @@ void kernel_mul_mv_iq1_bn_f32_impl( device const float * y4 = y + 32 * ix + 8 * ir; - typedef union { float f; uint32_t i; } scale_t; - scale_t scale; - for (int row = 0; row < N_DST; ++row) { - uint8_t u = x[nb*row].extra & 0xff; - scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); - d1bn[row] = scale.f; + d1bn[row] = iq1bn_fp8_to_float(x[nb*row].extra & 0xff); } 
uint32_t aux32[2]; @@ -5138,9 +5139,6 @@ void kernel_mul_mv_iq2_bn_f32_impl( device const float * y4 = y + 64 * ix + 4 * ir; - typedef union { float f; uint32_t i; } scale_t; - scale_t scale; - for (int row = 0; row < N_DST; ++row) { d1bn[row] = x[nb*row].d; } @@ -5945,15 +5943,10 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & template <typename type4x4> void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) { // il is in 0...3 - typedef union { float f; uint32_t i; } scale_t; - scale_t scale; - uint8_t u = xb->extra & 0xff; - scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); - //uint32_t u = xb->extra & 0xff; - //scale.i = (u << 19) + 905969664; + const float d = iq1bn_fp8_to_float(xb->extra & 0xff); uint8_t gs = xb->extra >> (8 + 2*il); - const float d1 = gs & 1 ? -scale.f : scale.f; - const float d2 = gs & 2 ? -scale.f : scale.f; + const float d1 = gs & 1 ? -d : d; + const float d2 = gs & 2 ? -d : d; uint32_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)]; uint32_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)]; @@ -5969,19 +5962,6 @@ void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 reg[2][i] = d2*aux8[2] - d2; reg[3][i] = d2*aux8[3] - d2; } - - //Basically same performance as above. 
I guess, the compiler makes the transformation automatically - //uint16_t v1 = iq1bn_grid_u16[xb->ql[2*il+0] | ((xb->qh[il] << 8) & 0x0f00)]; - //uint16_t v2 = iq1bn_grid_u16[xb->ql[2*il+1] | ((xb->qh[il] << 4) & 0x0f00)]; - //for (int i = 0; i < 4; ++i) { - // reg[0][i] = d1*((v1 >> 2*i) & 3) - d1; - // reg[2][i] = d2*((v2 >> 2*i) & 3) - d2; - //} - //v1 >>= 8; v2 >>= 8; - //for (int i = 0; i < 4; ++i) { - // reg[1][i] = d1*((v1 >> 2*i) & 3) - d1; - // reg[3][i] = d2*((v2 >> 2*i) & 3) - d2; - //} } template <typename type4x4> diff --git a/iqk-quantize.cpp b/iqk-quantize.cpp index 42e5a264..522ab2cd 100644 --- a/iqk-quantize.cpp +++ b/iqk-quantize.cpp @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "iqk-quantize.h" #include "ggml-quants.h" #include "ggml-impl.h" #define GGML_COMMON_IMPL_C @@ -81,10 +82,6 @@ IQ1BNData::IQ1BNData() { } struct IQ1BNQuantizer { - typedef union { - float f; - uint32_t i; - } scale_t; constexpr static int block_size = QK_IQ1BN; int8_t L[QK_IQ1BN]; void quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix); @@ -128,22 +125,11 @@ void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, i auto max_in_row = row_max(n_per_row, src); - max_in_row *= 1.03125f; // i.e., round to nearest in our fp8 representation - scale_t s; - uint8_t u = 0; - if (max_in_row > 1.9074e-06f && max_in_row < 0.12109f) { - s.f = max_in_row; - u = ((((s.i >> 23) + 132) & 0xf) << 4) | ((s.i >> 19) & 0xf); - s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); - } else { - // outside the allowed range. 
Small values we can habdle via quants set to zero, so we only warn about too large values - if (max_in_row >= 0.12109f) { - u = 255; - fprintf(stderr, "%s: found scale %g, which is outside the range of out fp8 representation\n", __func__, max_in_row); - } else{ - u = 0; - } + max_in_row *= 1.015625f; // i.e., round to nearest in our fp8 representation + if (max_in_row > iq1bn_max_value()) { + fprintf(stderr, "%s: found scale %g, which is outside the range of out fp8 representation\n", __func__, max_in_row); } + auto u = iq1bn_float_to_fp8(max_in_row); for (int ib = 0; ib < nblock; ++ib) { std::memset(&y[ib], 0, sizeof(block_iq1_bn)); @@ -205,12 +191,8 @@ void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) { assert(k%QK_IQ1BN == 0); int nblock = k / QK_IQ1BN; - IQ1BNQuantizer::scale_t s; - for (int i = 0; i < nblock; ++i) { - uint16_t u = x[i].extra & 0xff; - s.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); - float d = s.f; + float d = iq1bn_fp8_to_float(x[i].extra & 0xff); uint8_t extra = x[i].extra >> 8; auto qh = x[i].qh; auto ql = x[i].ql; @@ -276,11 +258,9 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz int nblock = n / QK_IQ1BN; float sumf = 0; - IQ1BNQuantizer::scale_t scale; for (int i = 0; i < nblock; ++i) { - uint16_t u = x[i].extra & 0xff; - scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + float d = iq1bn_fp8_to_float(x[i].extra & 0xff); uint8_t extra = x[i].extra >> 8; auto qh = x[i].qh; auto ql = x[i].ql; @@ -304,7 +284,7 @@ void ggml_vec_dot_iq1_bn_q8_0 (int n, float * s, size_t bs, const void * vx, siz sumi2 += extra & (1 << k) ? 
-sl : sl; q8 += 8; } - sumf += scale.f * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2); + sumf += d * (GGML_FP16_TO_FP32(y[2*i+0].d) * sumi1 + GGML_FP16_TO_FP32(y[2*i+1].d) * sumi2); } *s = sumf; @@ -325,10 +305,8 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si int nblock = n / QK_IQ1BN; float sumf = 0; - IQ1BNQuantizer::scale_t scale; - uint16_t u = x[0].extra & 0xff; - scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + float d = iq1bn_fp8_to_float(x[0].extra & 0xff); for (int i = 0; i < nblock; ++i) { uint8_t extra = x[i].extra >> 8; auto qh = x[i].qh; @@ -351,7 +329,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si sumi += extra & (1 << k) ? -sl : sl; q8 += 8; } - sumf += scale.f * (y[i].d) * sumi; + sumf += d * (y[i].d) * sumi; } *s = sumf; diff --git a/iqk-quantize.h b/iqk-quantize.h new file mode 100644 index 00000000..b89c9427 --- /dev/null +++ b/iqk-quantize.h @@ -0,0 +1,46 @@ +#pragma once + +#include <stdint.h> + +typedef union { + float f; + uint32_t i; +} iq1bn_scale_t; + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef BITNET_IQ1BN_4x4 +static inline float iq1bn_min_value(void) { return 1.9074e-06f; } +static inline float iq1bn_max_value(void) { return 0.12109f; } +#else +static inline float iq1bn_min_value(void) { return 0.000488281f; } +static inline float iq1bn_max_value(void) { return 0.123047f; } +#endif + +static inline uint8_t iq1bn_float_to_fp8(float f) { + if (f <= iq1bn_min_value()) return 0; + if (f >= iq1bn_max_value()) return 255; + iq1bn_scale_t s; + s.f = f; +#ifdef BITNET_IQ1BN_4x4 + return ((((s.i >> 23) + 132) & 0xf) << 4) | ((s.i >> 19) & 0xf); +#else + return ((s.i >> 18) & 0x1f) | (((s.i >> 23) - 116) << 5); +#endif +} + +static inline float iq1bn_fp8_to_float(uint8_t fp8) { + iq1bn_scale_t s; +#ifdef BITNET_IQ1BN_4x4 + s.i = ((((fp8 >> 4) | 0xf0) - 132) << 23) | ((fp8 & 0x0f) << 19); +#else + s.i = (((fp8 
>> 5) + 116) << 23) | ((fp8 & 0x1f) << 18); +#endif + return s.f; +} + +#ifdef __cplusplus +} +#endif diff --git a/iqk_mul_mat.cpp b/iqk_mul_mat.cpp index f4294d31..08f954e1 100644 --- a/iqk_mul_mat.cpp +++ b/iqk_mul_mat.cpp @@ -31,6 +31,7 @@ #include "ggml-impl.h" #include "ggml-quants.h" #include "iqk_mul_mat.h" +#include "iqk-quantize.h" #define GGML_COMMON_IMPL_C #include "ggml-common.h" @@ -1344,15 +1345,11 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const //auto step = bx / sizeof(block_iq1_bn); const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx); - typedef union { float f; uint32_t i; } scale_t; - - scale_t scale; for (int ix = 0; ix < nrc_x; ++ix) { x = (const block_iq1_bn *)((const char *)vx + ix*bx); - uint16_t u = x[0].extra & 0xff; - scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + float d1 = iq1bn_fp8_to_float(x[0].extra & 0xff); for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); @@ -1401,7 +1398,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const } for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, scale.f * hsum_float_8(accd[iy])); + info.store(ix, iy, d1 * hsum_float_8(accd[iy])); } } @@ -4128,15 +4125,11 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx); - typedef union { float f; uint32_t i; } scale_t; - - scale_t scale; for (int ix = 0; ix < nrc_x; ++ix) { x = (const block_iq1_bn *)((const char *)vx + ix*bx); - uint16_t u = x[0].extra & 0xff; - scale.i = ((((u >> 4) | 0xf0) - 132) << 23) | ((u & 0x0f) << 19); + float d1 = iq1bn_fp8_to_float(x[0].extra & 0xff); for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_f32(0.f); @@ -4186,7 +4179,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn } for (int iy = 0; iy < nrc_y; 
++iy) { - info.store(ix, iy, scale.f * vaddvq_f32(accd[iy])); + info.store(ix, iy, d1 * vaddvq_f32(accd[iy])); } } |