diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-03-18 07:37:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-03-18 07:37:10 +0100 |
commit | bdcae905c4cb0de1025a45a2bd6c2e646cc22be7 (patch) | |
tree | 94c4c82fc8729791fd42840c68f7fe34f0598886 /ggml/src/ggml-cuda | |
parent | dcdfad29f7d2b831f1c84751f00bda14cc359a84 (diff) |
Compile time option to use bf16 for qunts without MMQ kernels (#261)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu | 152 |
1 files changed, 117 insertions, 35 deletions
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index b9baee1b..8383f2d3 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -579,9 +579,16 @@ static __global__ void dequantize_block_iq4_ks(const void * __restrict__ vx, dst const uint8_t * q4 = x[i].qs + 16*ib + 4*il; const float d = scale * ((x[i].scales[ib] & 254) - 127); const int8_t * values = iq4k_values + ((x[i].scales[ib] & 1) << 4); - for (int j = 0; j < 4; ++j) { - y[j+ 0] = d * values[q4[j] & 0xf]; - y[j+16] = d * values[q4[j] >> 4]; + if constexpr (std::is_same_v<dst_t, nv_bfloat16>) { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = __float2bfloat16(d * values[q4[j] & 0xf]); + y[j+16] = __float2bfloat16(d * values[q4[j] >> 4]); + } + } else { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * values[q4[j] & 0xf]; + y[j+16] = d * values[q4[j] >> 4]; + } } } @@ -610,9 +617,16 @@ static __global__ void dequantize_block_iq4_kss(const void * __restrict__ vx, ds aux32[1] = ((aux32[0] >> 4) & 0x0f0f0f0f); aux32[0] &= 0x0f0f0f0f; const uint8_t * aux8 = (const uint8_t *)aux32; - for (int j = 0; j < 4; ++j) { - y[j+ 0] = d * values[aux8[j+0]]; - y[j+16] = d * values[aux8[j+4]]; + if constexpr (std::is_same_v<dst_t, nv_bfloat16>) { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = __float2bfloat16(d * values[aux8[j+0]]); + y[j+16] = __float2bfloat16(d * values[aux8[j+4]]); + } + } else { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * values[aux8[j+0]]; + y[j+16] = d * values[aux8[j+4]]; + } } } @@ -632,9 +646,16 @@ static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_ const float d2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32); const int8_t * values1 = iq4k_values + 16*((x[i].extra >> (2*ib+0)) & 1); const int8_t * values2 = iq4k_values + 16*((x[i].extra >> (2*ib+1)) & 1); - for (int j = 0; j < 4; ++j) { - y[j+ 0] = d1 * values1[q4[j] & 0xf]; - y[j+16] = d2 * values2[q4[j] >> 4]; + if constexpr (std::is_same_v<dst_t, nv_bfloat16>) { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = __float2bfloat16(d1 * values1[q4[j] & 0xf]); + y[j+16] = __float2bfloat16(d2 * values2[q4[j] >> 4]); + } + } else { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d1 * values1[q4[j] & 0xf]; + y[j+16] = d2 * values2[q4[j] >> 4]; + } } } @@ -656,12 +677,22 @@ static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_ const uint8_t * qs = x[i].qs + 32*ib64 + 2*il; const uint8_t * qh = x[i].qh + 2*il; const uint8_t extra = x[i].extra >> 4*(ib64%4); - for (int j = 0; j < 2; ++j) { - const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); - y[j+ 0] = dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)]; - y[j+16] = dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)]; - y[j+32] = dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)]; - y[j+48] = dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)]; + if constexpr (std::is_same_v<dst_t, nv_bfloat16>) { + for (int j = 0; j < 2; ++j) { + const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); + y[j+ 0] = __float2bfloat16(dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)]); + y[j+16] = __float2bfloat16(dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)]); + y[j+32] = __float2bfloat16(dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)]); + y[j+48] = __float2bfloat16(dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)]); + } + } else { + for (int j = 0; j < 2; ++j) { + const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); + y[j+ 0] = dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)]; + y[j+16] = dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)]; + y[j+32] = dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)]; + y[j+48] = dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)]; + } } } @@ -689,10 +720,17 @@ static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_ uint8_t q2 = (qs[j+16] & 0xf) | ((h2 & 0x03) << 4); uint8_t q3 = (qs[j+ 0] >> 4) | ((h1 & 0x0c) << 2); uint8_t q4 = (qs[j+16] >> 4) | ((h2 & 0x0c) << 2); - y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0)); - y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0)); - y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0)); - y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0)); + if constexpr (std::is_same_v<dst_t, nv_bfloat16>) { + y[j+ 0] = __float2bfloat16(dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0))); + y[j+16] = __float2bfloat16(dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0))); + y[j+32] = __float2bfloat16(dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0))); + y[j+48] = __float2bfloat16(dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0))); + } else { + y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0)); + y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0)); + y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0)); + y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0)); + } } } @@ -713,11 +751,20 @@ static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_ const float dl4 = d * (((x[i].scales[4*ib128+3] >> 4*(il/8)) & 0xf) - 8); const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; const int16_t extra = x[i].extra >> (8*ib128 + (il/8)); - for (int j = 0; j < 2; ++j) { - y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; - y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]; - y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]; - y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]; + if constexpr (std::is_same_v<dst_t, nv_bfloat16>) { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = __float2bfloat16(dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]); + y[j+32] = __float2bfloat16(dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]); + y[j+64] = __float2bfloat16(dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]); + y[j+96] = __float2bfloat16(dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]); + } + } else { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; + y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]; + y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]; + y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]; + } } } @@ -741,11 +788,20 @@ static __global__ void dequantize_block_iq2_ks(const void * __restrict__ vx, dst const float dl3 = d * (((x[i].scales[2*ib128+1] & 0xf) | ((extra >> 6) & 0x10)) - 16); const float dl4 = d * (((x[i].scales[2*ib128+1] >> 4) | ((extra >> 7) & 0x10)) - 16); const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; - for (int j = 0; j < 2; ++j) { - y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; - y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)]; - y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)]; - y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)]; + if constexpr (std::is_same_v<dst_t, nv_bfloat16>) { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = __float2bfloat16(dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]); + y[j+32] = __float2bfloat16(dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)]); + y[j+64] = __float2bfloat16(dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)]); + y[j+96] = __float2bfloat16(dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)]); + } + } else { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; + y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)]; + y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)]; + y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)]; + } } } @@ -768,12 +824,22 @@ static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_ const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; const uint8_t * qh = x[i].qh + 2*il; const int16_t extra = x[i].extra >> (8*ib128 + (il/8)); - for (int j = 0; j < 2; ++j) { - const uint8_t h = qh[j] >> (4*(ib128%2)); - y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]; - y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]; - y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]; - y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]; + if constexpr (std::is_same_v<dst_t, nv_bfloat16>) { + for (int j = 0; j < 2; ++j) { + const uint8_t h = qh[j] >> (4*(ib128%2)); + y[j+ 0] = __float2bfloat16(dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]); + y[j+32] = __float2bfloat16(dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]); + y[j+64] = __float2bfloat16(dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]); + y[j+96] = __float2bfloat16(dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]); + } + } else { + for (int j = 0; j < 2; ++j) { + const uint8_t h = qh[j] >> (4*(ib128%2)); + y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]; + y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]; + y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]; + y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]; + } } } @@ -1064,6 +1130,22 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { return convert_to_bf16_cuda<float>; case GGML_TYPE_F16: return convert_to_bf16_cuda<half>; + case GGML_TYPE_IQ2_KS: + return dequantize_row_iq2_ks_cuda<nv_bfloat16>; + case GGML_TYPE_IQ2_K: + return dequantize_row_iq2_k_cuda<nv_bfloat16>; + case GGML_TYPE_IQ3_K: + return dequantize_row_iq3_k_cuda<nv_bfloat16>; + case GGML_TYPE_IQ4_KSS: + return dequantize_row_iq4_kss_cuda<nv_bfloat16>; + case GGML_TYPE_IQ4_KS: + return dequantize_row_iq4_ks_cuda<nv_bfloat16>; + case GGML_TYPE_IQ4_K: + return dequantize_row_iq4_k_cuda<nv_bfloat16>; + case GGML_TYPE_IQ5_K: + return dequantize_row_iq5_k_cuda<nv_bfloat16>; + case GGML_TYPE_IQ6_K: + return dequantize_row_iq6_k_cuda<nv_bfloat16>; default: return nullptr; } |