author     Andrew Chan <andrewkchan.akc@gmail.com>       2025-05-22 23:17:52 -0700
committer  GitHub <noreply@github.com>                   2025-05-23 09:17:52 +0300
commit     a1c931c30ce9c5618ec56fe93234110343111710 (patch)
tree       1186fa100e56822f48d521be6df044db2fe428c0
parent     3efdd6df67cbfb9e1723d68ce704717daf6a967c (diff)
Trellis quants with CPU inference (#441)
* WIP
* WIP
* WIP
* Testing Trellis quantization
Using 12 bits per 8 weights I get a better rmse than
iq2_xxs. I still need to see how quantizing the group-of-8
scales will affect accuracy. By AVX2 SIMDifying the search
for the best code, LLaMA-3.1-8B gets quantized in 130 seconds
on the Ryzen-7950X CPU - sluggish but still acceptable.
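[Editor's note: for context, a minimal sketch of the trellis value generator these commits build on (the "3INST" scheme from the QTIP work referenced later). The constants mirror make_values() in the quantize-stats changes and trellis_next() in ggml-cuda/convert.cu further down; ggml_fp16_to_fp32 is assumed to come from ggml.h, as in the diff itself. Not the exact implementation, just the idea: a small integer code seeds an LCG, and each step emits one value as the sum of two fp16 numbers carved out of the 32-bit state.]

    #include <cstdint>
    #include "ggml.h"   // assumed to provide ggml_fp16_to_fp32

    // One step of the 3INST-style trellis generator.
    static float trellis_next_value(uint32_t & val) {
        constexpr uint32_t ka    = 89226354;    // LCG multiplier
        constexpr uint32_t kb    = 64248484;    // LCG increment
        constexpr uint32_t kmask = 0x8fff8fff;  // keep sign + low bits of each fp16 half
        constexpr uint32_t km32  = 0x3b603b60;  // fp16(0.922) replicated into both halves
        val = ka*val + kb;                      // advance the LCG state
        const uint32_t s = (val & kmask) ^ km32;
        return ggml_fp16_to_fp32(s & 0xffff) + ggml_fp16_to_fp32(s >> 16);
    }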
* Testing Trellis quantization: 4-bit quantized block scales
rmse increases by just 3%, so this is beating iq2_xxs in terms
of rmse at the same 2.0625 bpw.
* Testing Trellis quantization: playing with scales and generators
* iq2_kt: quantize / dequantize
I now see that I was comparing apples to oranges:
iq2_xxs was using a weight of sigma^2/4 + x^2, while
the Trellis approach wasn't (weight = 1). Once I use the same weight,
iq2_kt is actually slightly worse than iq2_xxs in terms
of rmse, so does not look promising at this point.
Also, once each group of 8 Trellis values no longer has a
constant sum(q^2) that we can precompute, quantization
becomes significantly slower (476 seconds for LLaMA-3.1-8B).
* iq2_kt: CUDA dequantize
so we can run perplexity calcs.
As already indicated by rmse, the 2-bit trellis approach is
quite a bit worse than iq2_xxs.
* WIP
* WIP
* WIP - try larger blocks
With blocks of 32 and 16 bits per group of 8, the brute force
search becomes prohibitive in terms of CPU time (30+ minutes
for 8B LLaMA after SIMDifying with AVX2). The trick is to
group the points in clusters, find the nearest cluster,
and only search within the cluster.
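[Editor's note: a scalar sketch of that two-level search; the AVX2 version lives in analyze_x_v2() in the quantize-stats changes below. Cluster construction (the k-means style cluster_points() helper) is omitted and the container layout is illustrative.]

    #include <vector>
    #include <limits>

    // Two-level code search: pick the nearest cluster centroid first, then do
    // the brute-force search only over the codes assigned to that cluster.
    static int nearest_code(const float * x, int ndim,
                            const std::vector<float> & clusters,    // ncluster x ndim centroids
                            const std::vector<float> & codes,       // ncode x ndim trellis points
                            const std::vector<std::vector<int>> & points_in_cluster) {
        auto dist2 = [ndim](const float * a, const float * b) {
            float d2 = 0;
            for (int k = 0; k < ndim; ++k) { float d = a[k] - b[k]; d2 += d*d; }
            return d2;
        };
        // 1st level: nearest cluster centroid
        const int ncluster = (int)(clusters.size()/ndim);
        int best_c = 0; float best_cd = std::numeric_limits<float>::max();
        for (int c = 0; c < ncluster; ++c) {
            const float d2 = dist2(x, clusters.data() + (size_t)c*ndim);
            if (d2 < best_cd) { best_cd = d2; best_c = c; }
        }
        // 2nd level: brute force only within that cluster
        int best_p = -1; float best_pd = std::numeric_limits<float>::max();
        for (int ip : points_in_cluster[best_c]) {
            const float d2 = dist2(x, codes.data() + (size_t)ip*ndim);
            if (d2 < best_pd) { best_pd = d2; best_p = ip; }
        }
        return best_p;
    }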
* iq2_kt - this is better
Using blocks of 32 and 16 bits per group of 8 weights
it beats iq2_xxs in terms of PPL by a significant margin.
It is 0.0625 bpw larger, but even if we go to 15 bits per
group of 8 (so 0.0625 bpw less than iq2_xxs), PPL is still
lower.
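[Editor's note: for orientation, the bit budget implied by these block sizes, assuming the 4-bit block scales introduced a few commits earlier; the per-row fp32 scale is negligible.]

    16 bits / group of 8 weights  = 2.000 bpw
    4-bit scale / block of 32     = 0.125 bpw
    total                         = 2.125 bpw   (iq2_xxs: 2.0625 bpw)
    with 15-bit codes instead     = 1.875 + 0.125 = 2.000 bpw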
* iq2_kt - even better
Re-quantize after determining block scales
(at the expense of much longer quantization time).
* iq2_kt: CUDA dot product
Implemented as DMMV.
Very slow - just 81 t/s for LLaMA-3.1-8B.
Then again, Q2_K_S, when forced to use DMMV, only
gets 112 t/s vs 145 t/s via MMVQ. My memory is that
when the DMMV kernels were properly maintained/used,
DMMV was about on par with MMVQ for k-quants on my GPU.
* iq2_kt: very slightly faster CUDA dot product
* iq2_kt: f16 CUDA dot product
We arrive at 112 t/s.
* iq2_kt: faster f16 CUDA dot product
We arrive at 139 t/s (no FA), and 149 t/s (FA).
My RTX-4080 is ~20% slower than the RTX-6000 quoted in the
QTIP repository, so with FA (which I'm sure they also used)
we are at around ~180 t/s on their GPU, so almost matching
their performance.
* iq2_kt: faster f16 CUDA dot product
We arrive at 146 t/s (no FA), and 158 t/s (FA).
This is measured for LLaMA-3.1-8B with output.weight
left as f16.
* Minor
* Adding iq3_kt
3.125 bpw. So far does not look good on the PPL vs bpw plot.
* Forgotten change
* WIP
* WIP
* iq3_kt WIP: slowly improving
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is
starting to be competitive/slightly better than other quants.
* WIP
* iq3_kt WIP: slowly improving
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7892
* iq3_kt WIP: slowly improving
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7689 after shrinking
by 0.015 bpw by using iq4_k instead of q5_k for attn_v.
* iq3_kt WIP: speed up quantization
Nearly 60% improvement of quantization speed by having the
points belonging to a cluster copied to contiguous memory
during initialization, and then accessed sequentially while
searching for the closest point. LLaMA-3.1-8B now gets
quantized in ~150 seconds on the Ryzen-5975WX.
* iq3_kt speed up quantization
Same trick as last commit applied to iq2_kt. Here we get
an even larger speedup: quantization time on the Ryzen-5975WX
for LLaMA-3.1-8B drops to 195 seconds from 375 seconds!
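[Editor's note: a sketch of the layout change both commits describe — the candidate codes of each cluster are copied into one contiguous array at initialization, so the hot search loop streams through memory sequentially instead of gathering codes through an index list. Names are illustrative.]

    #include <cstddef>
    #include <vector>

    struct packed_cluster {
        std::vector<float> codes;  // this cluster's candidate codes, back to back (ndim floats each)
        std::vector<int>   ids;    // original code index of each packed entry
    };

    // Pack each cluster's codes into contiguous storage once, at setup time.
    static std::vector<packed_cluster> pack_clusters(const std::vector<float> & all_codes, int ndim,
            const std::vector<std::vector<int>> & points_in_cluster) {
        std::vector<packed_cluster> packed(points_in_cluster.size());
        for (size_t c = 0; c < points_in_cluster.size(); ++c) {
            packed[c].ids = points_in_cluster[c];
            packed[c].codes.reserve(points_in_cluster[c].size()*ndim);
            for (int ip : points_in_cluster[c]) {
                const float * p = all_codes.data() + (size_t)ip*ndim;
                packed[c].codes.insert(packed[c].codes.end(), p, p + ndim);
            }
        }
        return packed;
    }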
* iq3_kt: CUDA dot product
* iq2_kt: SOTA
We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.2406
PPL(LLaMA-2-7B, 4096) = 6.4179
* iq2_kt: SOTA
We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642
PPL(LLaMA-2-7B, 4096) = 6.3920
* Adding iq4_kt - not competitive at this point
* WIP
* WIP
* iq4_kt: CUDA dot product
* iq4_kt: minor tweaks
* iq2_kt: SOTA
We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642
PPL(LLaMA-2-7B, 4096) = 6.3920
* iq2_kt: SOTA
We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.0297
PPL(LLaMA-2-7B, 4096) = 6.3913
Ah, quantization is faster too. About 20% faster.
* iq3_kt: small improvements and faster quantization
* iq2_kt: SOTA
We arrive at
PPL(LLaMA-3.1-8B-Instruct, 8192) = 8.9627
PPL(LLaMA-2-7B, 4096) = 6.3825
Quantization is faster too: ~200 seconds for LLaMA-3.1-8B
on Ryzen-5975WX.
* iq3_kt: small progress
* WIP
* iq4_kt: go to 4.0 bpw
15 bits per group of 4, plus 8-bit scales for blocks of 32.
This gives a slightly better PPL than iq4_kss.
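[Editor's note: the bit budget works out as follows; the two per-row floats (scale and mean) are negligible.]

    15 bits / group of 4 weights  = 3.75 bpw
    8-bit scale / block of 32     = 0.25 bpw
    total                         = 4.00 bpw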
* iq4_kt: very slightly better
at the expense of much longer quantization time.
* iq4_kt: failed attempt to adjust CUDA dot product
It was working for 4.125 bpw. But after changing to 4.0 bpw
there is something wrong and I don't see the bug.
* DRY
* DRY
* iq4_kt: CUDA dot product works
* DRY
* Report actual bpw
* Minor tweaks
* Checkpoint
Go to groups of 8 for iq3_kt. 2 x 8 = 16 bits for the magnitude
plus 1 bpw for the sign. It gives a visible improvement in the
PPL vs bpw plot, but that comes at the expense of much longer
quantization time (7.5 minutes for LLaMA-3.1-8B on the Ryzen-5975WX).
I also noticed that the 3INST generator is not actually generating a
Gaussian distribution. But going to a better generator means
readjusting all the hyper-parameters, so leaving it for later.
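[Editor's note: a scalar sketch of the magnitude/sign split described above, mirroring the iq3_kt CUDA dequantization in the diff below and reusing trellis_next_value() from the sketch near the top. The qh indexing is simplified here, and dl stands for the decoded block scale.]

    #include <cmath>
    #include <cstdint>

    // Decode one group of 8 iq3_kt weights: a 16-bit code drives the trellis
    // generator for the magnitudes, and one bit per weight from the sign plane
    // qh restores the sign. Budget: 16/8 + 1 = 3 bpw, plus block scales.
    static void decode_iq3_kt_group(uint32_t idx, const uint8_t * qh, uint8_t mask,
                                    float dl, float * y) {
        for (int j = 0; j < 8; ++j) {
            const float v = std::fabs(trellis_next_value(idx));  // magnitude from the trellis
            y[j] = dl * ((qh[j] & mask) ? -v : v);                // sign from the 1-bpw plane
        }
    }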
* WIP for IQ2_KT
* WIP - working basic iq2_kt
* still super slow (0.17t/s eval)
* flatten 3inst iters + avx2 (0.3t/s eval)
* iq3_kt (0.3t/s eval) and renames
* wip buggy iq4_KT
* fix (0.22t/s eval)
* naming and remove unused fn
* cleanup
* more cleanup
* delete unused and noncompiling mmvq functions
* Some performance tweaks
* Slighty faster iq2_kt
* port Trellis struct to iq3_kt, iq4_kt
* oops untracked files
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | examples/quantize-stats/CMakeLists.txt     |   17
-rw-r--r-- | examples/quantize-stats/quantize-stats.cpp |  595
-rw-r--r-- | examples/quantize/quantize.cpp             |    3
-rw-r--r-- | ggml/include/ggml.h                        |    6
-rw-r--r-- | ggml/src/CMakeLists.txt                    |    2
-rw-r--r-- | ggml/src/ggml-common.h                     |   18
-rw-r--r-- | ggml/src/ggml-cuda.cu                      |    4
-rw-r--r-- | ggml/src/ggml-cuda/common.cuh              |    7
-rw-r--r-- | ggml/src/ggml-cuda/convert.cu              |  128
-rw-r--r-- | ggml/src/ggml-cuda/dmmv.cu                 |  264
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cu                 |   38
-rw-r--r-- | ggml/src/ggml-cuda/mmvq.cuh                |    1
-rw-r--r-- | ggml/src/ggml-quants.c                     |    3
-rw-r--r-- | ggml/src/ggml.c                            |   66
-rw-r--r-- | ggml/src/iqk/iqk_gemm_ktquants.cpp         |  403
-rw-r--r-- | ggml/src/iqk/iqk_gemm_ktquants.h           |   11
-rw-r--r-- | ggml/src/iqk/iqk_mul_mat.cpp               |    7
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp              | 1386
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h                |   18
-rw-r--r-- | include/llama.h                            |    3
-rw-r--r-- | src/llama.cpp                              |   73
21 files changed, 3028 insertions, 25 deletions
diff --git a/examples/quantize-stats/CMakeLists.txt b/examples/quantize-stats/CMakeLists.txt index bb986a71..ae74f016 100644 --- a/examples/quantize-stats/CMakeLists.txt +++ b/examples/quantize-stats/CMakeLists.txt @@ -1,6 +1,21 @@ +set(ARCH_FLAGS "") +if (NOT MSVC) + list(APPEND ARCH_FLAGS -march=native) +endif() +message(STATUS "ARCH_FLAGS = ${ARCH_FLAGS}") +#if (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR +# (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND +# CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$")) +# message(STATUS "x86 detected") +# if (NOT MSVC) +# list(APPEND ARCH_FLAGS -march=native) +# endif() +#endif() + +add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>") set(TARGET llama-quantize-stats) add_executable(${TARGET} quantize-stats.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT}) target_include_directories(${TARGET} PRIVATE ../../common) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 79905f54..a49ebd92 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -29,6 +29,7 @@ #include <thread> #include <mutex> #include <array> +#include <random> #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -48,6 +49,10 @@ constexpr int popcount(uint32_t x) { return __builtin_popcount(x); } constexpr int popcount(uint64_t x) { return __builtin_popcountll(x); } #endif +#ifdef __AVX2__ +#include <immintrin.h> +#endif + struct quantize_stats_params { std::string model = DEFAULT_MODEL_PATH; bool verbose = false; @@ -253,6 +258,575 @@ static void test_roundtrip_on_layer( } } +static inline int nearest_int(float fval) { + assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +static const int8_t scale_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + +static std::vector<float> make_values(int nval, int n_per_val, float scale = 16.f) { + std::vector<float> result(nval*n_per_val); + uint16_t m16 = ggml_fp32_to_fp16(0.922f); + uint32_t m32 = (uint32_t(m16) << 16) | m16; + const uint32_t a = 89226354, b = 64248484; + float * data = result.data(); + for (int i = 0; i < nval; ++i) { + uint32_t x = i + 4096; + for (int k = 0; k < n_per_val; ++k) { + x = a*x + b; + uint32_t s = (x & 0b10001111111111111000111111111111) ^ m32; + float val = ggml_fp16_to_fp32(s & 65535) + ggml_fp16_to_fp32(s >> 16); + int ival = nearest_int(scale*val); + data[k] = ival; + } + data += n_per_val; + } + return result; +} + +#ifdef __AVX2__ +static inline float hsum_float_4(__m128 x) { + x = _mm_add_ps(x, _mm_movehl_ps(x, x)); + x = _mm_add_ss(x, _mm_movehdup_ps(x)); + return _mm_cvtss_f32(x); +} +static inline float hsum_float_8(__m256 x) { + return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); +} +static __m256 hsum_float_8x8(__m256 * accm) { + for (int i = 0; i < 4; ++i) { + accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)), + _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); + } + for (int i = 0; i < 2; ++i) accm[i] = 
_mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} +#endif + +const int8_t scale_index[241] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 19, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 20, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 21, 21, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 22, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 23, 23, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 24, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 25, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 26, 26, + 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 27, 27, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 28, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 29, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 30, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15 +}; +inline int best_index_scale(const int8_t * values, float x) { + int ix = (int)x - values[0]; + if (ix < 0 || ix >= 241) return ix < 0 ? 0 : 15; + ix = scale_index[ix]; + return ix < 16 ? ix : x - values[ix-16] < values[ix-15] - x ? ix-16 : ix-15; +} +inline int best_index_iq4nl(const int8_t * values, float x) { return best_index_scale(values, x); } + +static float find_best_scale(int block_size, const float * xb, const float * weight, const int8_t * values, int ntry) { + float amax = 0, max = 0; + for (int j = 0; j < block_size; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + return amax/96.f; //120.f; //127.f; + if (!amax) return 0.f; + float d = ntry > 0 ? 
-max/values[0] : max/values[0]; + float id = 1/d; + float sumqx_p = 0, sumq2_p = 0; + float sumqx_m = 0, sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + d = sumqx_p/sumq2_p; + float best = d*sumqx_p; + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d*sumqx_m; + } + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry + values[0])/max; + sumqx_p = sumq2_p = 0; + sumqx_m = sumq2_m = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + sumqx_p += w*q*xb[j]; + sumq2_p += w*q*q; + l = best_index_iq4nl(values, -al); + q = values[l]; + sumqx_m += w*q*xb[j]; + sumq2_m += w*q*q; + } + if (sumq2_p > 0 && sumqx_p*sumqx_p > best*sumq2_p) { + d = sumqx_p/sumq2_p; best = d * sumqx_p; + } + if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { + d = sumqx_m/sumq2_m; best = d * sumqx_m; + } + } + return d; +} + +static std::vector<float> cluster_points(const std::vector<float>& points, int ndim, int ncluster, int niter) { + if (points.size() % ndim != 0) { + printf("%s: bad input\n", __func__); return {}; + } + int npoint = points.size() / ndim; + if (npoint < 2*ncluster) { + printf("%s: bad input\n", __func__); return {}; + } + std::vector<std::pair<float, float>> range(ndim, std::make_pair(INFINITY, -INFINITY)); + double Fo = 0; + for (int i = 0; i < npoint; ++i) { + auto v = points.data() + i*ndim; + for (int k = 0; k < ndim; ++k) { + Fo += v[k]*v[k]; + range[k].first = std::min(range[k].first, v[k]); + range[k].second = std::max(range[k].second, v[k]); + } + } + printf("%s (ndim = %d, npoint = %d): Fo = %g\n", __func__, ndim, npoint, Fo/points.size()); + std::mt19937 rndm(1234); + float scale = 1.f/4294967296.f; + std::vector<float> result(ncluster*ndim); + for (int i = 0; i < ncluster; ++i) { + auto v = result.data() + i*ndim; + for (int k = 0; k < ndim; ++k) v[k] = range[k].first + (range[k].second - range[k].first)*scale*rndm(); + } + std::vector<float> sump(ncluster*ndim); + std::vector<int> counts(ncluster); + std::vector<int> which_cluster(npoint, -1); + double Flast = Fo; + for (int iter = 0; iter < niter; ++iter) { + std::memset(sump.data(), 0, sump.size()*sizeof(float)); + std::memset(counts.data(), 0, counts.size()*sizeof(int)); + int nchanged = 0; + double F = 0; + for (int ip = 0; ip < npoint; ++ip) { + auto vp = points.data() + ndim*ip; + float best = INFINITY; int ibest = -1; + for (int ic = 0; ic < ncluster; ++ic) { + auto vc = result.data() + ndim*ic; + float dist2 = 0; + for (int k = 0; k < ndim; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + if (dist2 < best) { + best = dist2; ibest = ic; + } + } + if (ibest < 0) { printf("Oops.\n"); exit(1); } + F += best; + if (which_cluster[ip] != ibest) ++nchanged; + which_cluster[ip] = ibest; + ++counts[ibest]; + auto vc = sump.data() + ndim*ibest; + for (int k = 0; k < ndim; ++k) vc[k] += vp[k]; + } + if (nchanged == 0) break; + for (int ic = 0; ic < ncluster; ++ic) { + float norm = counts[ic] > 0 ? 
1.f/counts[ic] : 0.f; + auto vc = sump.data() + ndim*ic; + auto r = result.data() + ndim*ic; + for (int k = 0; k < ndim; ++k) r[k] = vc[k]*norm; + } + printf("%s(iteration %d): F = %g, nchanged = %d\n", __func__, iter+1, F/points.size(), nchanged); + if (iter > 1 && Flast/F - 1 < 1e-6) break; + Flast = F; + } + return result; +} + +static void analyze_x_v2(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_mse_q, float& tot_elements) { + constexpr int kNumVal = 1 << 15; + constexpr int kBlockSize = 32; + constexpr int kGroupSize = 8; + constexpr int kNg = kBlockSize/kGroupSize; + constexpr int kSuperBlockSize = 256; + static_assert(kNumVal%8 == 0); + static std::vector<float> codes, clusters; + static std::vector<std::vector<int>> p_in_cluster; + if (codes.empty()) { + codes = make_values(kNumVal, kGroupSize, 31.75f); + clusters = cluster_points(codes, kGroupSize, kNumVal/512, 200); + if (clusters.empty()) { printf("Oops\n"); exit(1); } + int ncluster = clusters.size()/kGroupSize; + p_in_cluster.resize(ncluster); + std::vector<int> which_cluster(4*kNumVal); + GGML_ASSERT(ncluster%8 == 0); + for (int ip = 0; ip < kNumVal; ++ip) { + auto vp = codes.data() + ip*kGroupSize; + float best[4] = {INFINITY, INFINITY, INFINITY, INFINITY}; + int ibest[4] = {-1, -1, -1, -1}; + for (int ic = 0; ic < ncluster; ++ic) { + auto vc = clusters.data() + ic*kGroupSize; + float dist2 = 0; + for (int k = 0; k < kGroupSize; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + if (dist2 < best[0]) { + best[3] = best[2]; ibest[3] = ibest[2]; + best[2] = best[1]; ibest[2] = ibest[1]; + best[1] = best[0]; ibest[1] = ibest[0]; + best[0] = dist2; ibest[0] = ic; + } + else if (dist2 < best[1]) { + best[3] = best[2]; ibest[3] = ibest[2]; + best[2] = best[1]; ibest[2] = ibest[1]; + best[1] = dist2; ibest[1] = ic; + } + else if (dist2 < best[2]) { + best[3] = best[2]; ibest[3] = ibest[2]; + best[2] = dist2; ibest[2] = ic; + } + else if (dist2 < best[3]) { + best[3] = dist2; ibest[3] = ic; + } + } + GGML_ASSERT(ibest[0] >= 0 && ibest[1] >= 0 && ibest[2] >= 0 && ibest[3] >= 0); + p_in_cluster[ibest[0]].push_back(ip); + p_in_cluster[ibest[1]].push_back(ip); + p_in_cluster[ibest[2]].push_back(ip); + p_in_cluster[ibest[3]].push_back(ip); + std::memcpy(which_cluster.data() + 4*ip, ibest, 4*sizeof(int)); + } + std::vector<std::pair<float, int>> extra; + extra.reserve(kNumVal); + for (int ic = 0; ic < ncluster; ++ic) { + auto& points = p_in_cluster[ic]; + if (!points.empty() && points.size()%8 == 0) continue; + extra.clear(); + auto vc = clusters.data() + ic*kGroupSize; + for (int ip = 0; ip < kNumVal; ++ip) { + if (which_cluster[4*ip] == ic || which_cluster[4*ip+1] == ic || which_cluster[4*ip+2] == ic || which_cluster[4*ip+3] == ic) continue; + auto vp = codes.data() + ip*kGroupSize; + float dist2 = 0; + for (int k = 0; k < kGroupSize; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + extra.push_back(std::make_pair(dist2, ip)); + } + std::sort(extra.begin(), extra.end()); + int nadd = 8*((points.size()+7)/8) - points.size(); + for (int i = 0; i < nadd; ++i) points.push_back(extra[i].second); + GGML_ASSERT(points.size()%8 == 0); + } + auto min = p_in_cluster.front().size(), max = p_in_cluster.front().size(); + int nzero = 0; + for (auto& points : p_in_cluster) { + min = std::min(min, points.size()); + max = std::max(max, points.size()); + if (points.empty()) ++nzero; + } + printf("%s: prepared %d clusters\n", __func__, ncluster); + printf(" min number of points in a cluster: %d\n", 
int(min)); + printf(" max number of points in a cluster: %d\n", int(max)); + if (nzero > 0) { + printf(" there are %d empty clusters\n", nzero); + for (auto& points : p_in_cluster) { + if (!points.empty()) continue; + points.reserve(kNumVal); + for (int j = 0; j < kNumVal; ++j) points.push_back(j); // i.e., if we end iup picking an empty cluster, we just check all points + } + } + } + int nthread = std::max(1, int(std::thread::hardware_concurrency()/2)); + int chunk = (nrows + 8*nthread - 1)/(8*nthread); + std::mutex mutex; + int counter = 0; + float mse = 0, mse_q = 0; + auto compute = [&mutex, &counter, &mse, &mse_q, values, nrows, n_per_row, chunk] () { + double lmse = 0, lmse_q = 0; + std::vector<float> scales(n_per_row/kBlockSize); + std::vector<int> best_idx(n_per_row/kGroupSize); + std::vector<float> weight(kBlockSize, 1.f); + int ncluster = clusters.size() / kGroupSize; + while (true) { + std::unique_lock<std::mutex> lock(mutex); + int first = counter; counter += chunk; + if (first >= nrows) { + mse += lmse; mse_q += lmse_q; + return; + } + lock.unlock(); + int last = std::min(first + chunk, nrows); +#ifdef __AVX2__ + __m256 sqx[8]; + __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + float sx[8]; + int index[8]; +#endif + for (int row = first; row < last; ++row) { + auto xr = values + row*n_per_row; + float sigma2 = 0; + for (int j = 0; j < n_per_row; ++j) sigma2 += xr[j]*xr[j]; + sigma2 /= n_per_row; + for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) { + auto xb = xr + kBlockSize*ib; + //for (int i = 0; i < kBlockSize; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i]; + float d = find_best_scale(kBlockSize, xb, weight.data(), iq4k_values, 5); + float id = d ? 1/d : 0.f; +#ifdef __AVX2__ + auto vid = _mm256_set1_ps(id); + for (int l = 0; l < kNg; ++l) { + auto xl = xb + 8*l; + auto wl = weight.data() + 8*l; + auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xl)); + auto vw = _mm256_loadu_ps(wl); + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + float best = INFINITY; int jbest = -1; + for (int j = 0; j < ncluster; j += 8) { + auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx); + for (int i = 0; i < 8; ++i) { + auto vq = _mm256_loadu_ps(clusters.data() + kGroupSize*(j+i)); + auto vdiff = _mm256_sub_ps(vq, vx); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + } + auto score = hsum_float_8x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + auto& points = p_in_cluster[jbest]; + if (points.empty()) { + printf("Oops: empty cluster %d\n", jbest); + auto vc = clusters.data() + kGroupSize*jbest; + printf("Cluster:\n"); + for (int j = 0; j < kGroupSize; ++j) printf("%d %g %g\n", j, vc[j], xl[j]); + GGML_ASSERT(false); + } + int jbest_cluster = jbest; + vbest = _mm256_set1_ps(INFINITY); + best_index = _mm256_set1_epi32(-1); + best = INFINITY; jbest = -1; + for (int j = 0; j < int(points.size()); j += 8) { + auto idx = _mm256_loadu_si256((const __m256i*)(points.data() + j)); + for (int i = 0; i < 8; ++i) { + auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*points[j+i]); + auto vdiff = _mm256_sub_ps(vq, vx); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, 
vdiff)); + } + auto score = hsum_float_8x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + if (jbest < 0) { + printf("Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size())); + GGML_ASSERT(false); + } + GGML_ASSERT(jbest >= 0); + best_idx[ib*kNg + l] = jbest; + } + auto vqx = _mm256_setzero_ps(); + auto vq2 = _mm256_setzero_ps(); + for (int l = 0; l < kNg; ++l) { + auto vx = _mm256_loadu_ps(xb+8*l); + auto vw = _mm256_loadu_ps(weight.data() + 8*l); + auto vq = _mm256_loadu_ps(codes.data() + kGroupSize*best_idx[ib*kNg + l]); + auto vqw = _mm256_mul_ps(vq, vw); + vqx = _mm256_fmadd_ps(vqw, vx, vqx); + vq2 = _mm256_fmadd_ps(vqw, vq, vq2); + } + auto sumqx = hsum_float_8(vqx); + auto sumq2 = hsum_float_8(vq2); + scales[ib] = sumq2 > 0 ? sumqx/sumq2 : 0.f; +#else +#endif + } + float amax_scale = std::abs(scales[0]); + float max_scale = scales[0]; + for (int ib = 1; ib < n_per_row/kBlockSize; ++ib) { + float ax = std::abs(scales[ib]); + if (ax > amax_scale) { + amax_scale = ax; + max_scale = scales[ib]; + } + } + float d = max_scale/scale_values[0]; + float id = d ? 1/d : 0.f; + for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) { + int ls = best_index_scale(scale_values, id*scales[ib]); + float dl = d * scale_values[ls]; + auto xb = xr + kBlockSize*ib; + for (int l = 0; l < kNg; ++l) { + auto q = codes.data() + kGroupSize*best_idx[ib*kNg+l]; + for (int k = 0; k < kGroupSize; ++k) { + float diff1 = xb[kGroupSize*l + k] - scales[ib]*q[k]; + float diff2 = xb[kGroupSize*l + k] - dl*q[k]; + lmse += diff1*diff1; + lmse_q += diff2*diff2; + } + } + } + } + } + }; + std::vector<std::thread> workers(nthread-1); + for (auto& w : workers) w = std::thread(compute); + compute(); + for (auto& w : workers) w.join(); + tot_mse += mse; + tot_mse_q += mse_q; + tot_elements += n_per_row*nrows; + printf("%s: %g %g %g %g\n", name, sqrt(mse/(n_per_row*nrows)), sqrt(tot_mse/tot_elements), + sqrt(mse_q/(n_per_row*nrows)), sqrt(tot_mse_q/tot_elements)); +} + +static void analyze_x(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_mse_q, float& tot_elements) { + constexpr int kNumVal = 1 << 12; + constexpr int kBlockSize = 8; + constexpr int kSuperBlockSize = 256; + static_assert(kNumVal%8 == 0); + auto codes = make_values(kNumVal, kBlockSize); + std::vector<float> sumq2i(kNumVal); + for (int j = 0; j < kNumVal; ++j) { + auto data = codes.data() + kBlockSize*j; + float sum = 0; for (int k = 0; k < kBlockSize; ++k) sum += data[k]*data[k]; + sumq2i[j] = sum > 0 ? 
1/sum : 0.f;; + } + int nthread = std::max(1, int(std::thread::hardware_concurrency()/2)); + int chunk = (nrows + 8*nthread - 1)/(8*nthread); + std::mutex mutex; + int counter = 0; + float mse = 0, mse_q = 0; + auto compute = [&mutex, &counter, &mse, &mse_q, &codes, &sumq2i, values, nrows, n_per_row, chunk] () { + float lmse = 0, lmse_q = 0; + std::vector<float> scales(n_per_row/kBlockSize); + std::vector<int> best_idx(n_per_row/kBlockSize); + while (true) { + std::unique_lock<std::mutex> lock(mutex); + int first = counter; counter += chunk; + if (first >= nrows) { + mse += lmse; mse_q += lmse_q; + return; + } + lock.unlock(); + int last = std::min(first + chunk, nrows); +#ifdef __AVX2__ + __m256 vx[kBlockSize/8]; + __m256 sqx[8]; + __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + float sx[8]; + int index[8]; +#endif + for (int row = first; row < last; ++row) { + auto xr = values + row*n_per_row; + for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) { + float best = 0, d = 0; int jbest = -1; + auto xb = xr + kBlockSize*ib; +#ifdef __AVX2__ + for (int l = 0; l < kBlockSize/8; ++l) { + vx[l] = _mm256_loadu_ps(xb+8*l); + } + auto vbest = _mm256_set1_ps(0.f); + auto best_index = _mm256_set1_epi32(-1); + for (int j = 0; j < kNumVal; j += 8) { + auto idx = _mm256_add_epi32(_mm256_set1_epi32(j), add_idx); + for (int i = 0; i < 8; ++i) { + sqx[i] = _mm256_setzero_ps(); + for (int l = 0; l < kBlockSize/8; ++l) { + auto qv = _mm256_loadu_ps(codes.data() + kBlockSize*(j+i) + 8*l); + sqx[i] = _mm256_fmadd_ps(vx[l], qv, sqx[i]); + } + } + auto sumqx = hsum_float_8x8(sqx); + auto score = _mm256_mul_ps(_mm256_mul_ps(sumqx, sumqx), _mm256_loadu_ps(sumq2i.data() + j)); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_GT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(idx, _mm256_castps_si256(mask)), _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_max_ps(vbest, score); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + best = sx[0]; jbest = index[0]; + for (int j = 1; j < 8; ++j) { + if (sx[j] > best) { best = sx[j]; jbest = index[j]; } + } + auto qv = codes.data() + kBlockSize*jbest; + float sumqx = 0; + for (int k = 0; k < kBlockSize; ++k) sumqx += xb[k]*qv[k]; + d = sumqx*sumq2i[jbest]; +#else + for (int j = 0; j < kNumVal; ++j) { + if (!sumq2i[j]) continue; + auto qv = codes.data() + kBlockSize*j; + float sumqx = 0; + for (int k = 0; k < kBlockSize; ++k) sumqx += qv[k]*xb[k]; + if (sumqx*sumqx*sumq2i[j] > best]) { + d = sumqx*sumq2i[j]; best = d*sumqx; jbest = j; + } + } + auto qv = codes.data() + kBlockSize*jbest; +#endif + scales[ib] = d; + best_idx[ib] = jbest; + for (int k = 0; k < kBlockSize; ++k) { + float diff = xb[k] - d*qv[k]; + lmse += diff*diff; + } + } + float amax_scale = std::abs(scales[0]); + float max_scale = scales[0]; + for (int ib = 1; ib < n_per_row/kBlockSize; ++ib) { + float ax = std::abs(scales[ib]); + if (ax > amax_scale) { + amax_scale = ax; + max_scale = scales[ib]; + } + } + float d = max_scale/scale_values[0]; + float id = d ? 
1/d : 0.f; + for (int ib = 0; ib < n_per_row/kBlockSize; ++ib) { + int ls = best_index_scale(scale_values, id*scales[ib]); + float dl = d * scale_values[ls]; + auto xb = xr + kBlockSize*ib; + auto qv = codes.data() + kBlockSize*best_idx[ib]; + for (int k = 0; k < kBlockSize; ++k) { + float diff = xb[k] - dl*qv[k]; + lmse_q += diff*diff; + } + } + } + } + }; + std::vector<std::thread> workers(nthread-1); + for (auto& w : workers) w = std::thread(compute); + compute(); + for (auto& w : workers) w.join(); + tot_mse += mse; + tot_mse_q += mse_q; + tot_elements += n_per_row*nrows; + printf("%s: %g %g %g %g\n", name, sqrt(mse/(n_per_row*nrows)), sqrt(tot_mse/tot_elements), + sqrt(mse_q/(n_per_row*nrows)), sqrt(tot_mse_q/tot_elements)); +} + static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const float * values, float& tot_mse, float& tot_elements) { int row_size = ggml_row_size(GGML_TYPE_IQ4_KS, n_per_row); int nblock = n_per_row/QK_K; @@ -302,17 +876,6 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo lmse += diff4; } else { float best = std::numeric_limits<float>::max(); - //for (int k = 0; k < 16; k += 4) { - // uint16_t v = v0 ^ (1 << k); - // uint8_t v1 = v; - // uint8_t v2 = v >> 8; - // diff1 = xb[j+ 0] - dl*values[v1 & 0xf]; - // diff2 = xb[j+16] - dl*values[v1 >> 4]; - // diff3 = xb[j+ 1] - dl*values[v2 & 0xf]; - // diff4 = xb[j+17] - dl*values[v2 >> 4]; - // float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4; - // if (score < best) best = score; - //} for (int k = 0; k < 4; ++k) { uint16_t v = (v0 >> 4*k) & 0xf; auto pc = popcount(v); @@ -345,12 +908,12 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo printf("%s: %g %g %g\n", name, sqrt(mse0/(n_per_row*nrows)), sqrt(mse/(n_per_row*nrows)), sqrt(tot_mse/tot_elements)); } -static void analyze_iq4ks(const ggml_tensor * t, float& tot_mse, float& tot_elements) { +static void analyze_iq4ks(const ggml_tensor * t, float& tot_mse, float& tot_mse_q, float& tot_elements) { if (!ggml_is_contiguous(t) || (t->type != GGML_TYPE_F32 && t->type != GGML_TYPE_F16 && t->type != GGML_TYPE_BF16)) { return; } if (t->type == GGML_TYPE_F32) { - analyze_iq4ks(t->name, t->ne[1], t->ne[0], (const float *)t->data, tot_mse, tot_elements); + analyze_x_v2(t->name, t->ne[1], t->ne[0], (const float *)t->data, tot_mse, tot_mse_q, tot_elements); } else { std::vector<float> aux(t->ne[0]*t->ne[1]); if (t->type == GGML_TYPE_F16) { @@ -358,7 +921,7 @@ static void analyze_iq4ks(const ggml_tensor * t, float& tot_mse, float& tot_elem } else { ggml_bf16_to_fp32_row((const ggml_bf16_t *)t->data, aux.data(), aux.size()); } - analyze_iq4ks(t->name, t->ne[1], t->ne[0], aux.data(), tot_mse, tot_elements); + analyze_x_v2(t->name, t->ne[1], t->ne[0], aux.data(), tot_mse, tot_mse_q, tot_elements); } } @@ -542,7 +1105,7 @@ int main(int argc, char ** argv) { std::vector<float> output_scratch; if (analyze) { - float tot_mse = 0, tot_elements = 0; + float tot_mse = 0, tot_mse_q = 0, tot_elements = 0; for (const auto& kv_tensor : tensors) { if (!layer_included(params, kv_tensor.first)) { continue; @@ -551,7 +1114,7 @@ int main(int argc, char ** argv) { // we never quantize those continue; } - analyze_iq4ks(kv_tensor.second, tot_mse, tot_elements); + analyze_iq4ks(kv_tensor.second, tot_mse, tot_mse_q, tot_elements); } return 0; } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index b5277ec1..85ceabfd 100644 --- a/examples/quantize/quantize.cpp +++ 
b/examples/quantize/quantize.cpp @@ -46,6 +46,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "Q2_K_R4", LLAMA_FTYPE_MOSTLY_Q2_K_R4, "Q2_K_S repacked", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", }, + { "IQ3_KT", LLAMA_FTYPE_MOSTLY_IQ3_KT, " 3.125 bpw trellis quantization", }, + { "IQ4_KT", LLAMA_FTYPE_MOSTLY_IQ4_KT, " 4.0 bpw trellis quantization", }, { "IQ3_XXS_R4",LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4,"IQ3_XXS repacked", }, { "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", }, { "IQ3_S_R4", LLAMA_FTYPE_MOSTLY_IQ3_S_R4, "IQ3_S repacked", }, @@ -73,6 +75,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = { { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",}, { "IQ2_K_R4", LLAMA_FTYPE_MOSTLY_IQ2_K_R4, "IQ2_K repacked",}, { "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",}, + { "IQ2_KT", LLAMA_FTYPE_MOSTLY_IQ2_KT, " 2.125 bpw trellis quantization", }, { "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", }, { "IQ3_K_R4", LLAMA_FTYPE_MOSTLY_IQ3_K_R4, "IQ3_K repacked", }, { "IQ3_KL", LLAMA_FTYPE_MOSTLY_IQ3_KL, " 4 bpw non-linear quantization mix",}, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a04c7d43..0a14ba57 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -426,6 +426,9 @@ extern "C" { GGML_TYPE_Q8_K128 = 150, GGML_TYPE_Q8_KV = 151, GGML_TYPE_IQ5_KS = 152, + GGML_TYPE_IQ2_KT = 153, + GGML_TYPE_IQ3_KT = 154, + GGML_TYPE_IQ4_KT = 155, GGML_TYPE_Q4_0_R8 = 202, GGML_TYPE_Q5_0_R4 = 206, @@ -515,6 +518,9 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_KV = 140, // except 1d tensors GGML_FTYPE_MOSTLY_IQ5_KS = 141, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_KT = 142, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ3_KT = 143, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_KT = 144, // except 1d tensors // GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 9872b3de..b0db417d 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -268,6 +268,7 @@ if (GGML_IQK_MUL_MAT) iqk/fa/iqk_fa_64_64.cpp iqk/iqk_gemm_floats.cpp iqk/iqk_gemm_kquants.cpp + iqk/iqk_gemm_ktquants.cpp iqk/iqk_gemm_iquants.cpp iqk/iqk_gemm_iqk_quants.cpp iqk/iqk_gemm_1bit.cpp @@ -277,6 +278,7 @@ if (GGML_IQK_MUL_MAT) iqk/fa/iqk_fa_templates.h iqk/iqk_gemm_floats.h iqk/iqk_gemm_kquants.h + iqk/iqk_gemm_ktquants.h iqk/iqk_gemm_iquants.h iqk/iqk_gemm_iqk_quants.h iqk/iqk_gemm_1bit.h diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 26041ac2..5fe27b29 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -621,6 +621,24 @@ typedef struct { static_assert(sizeof(block_iq2_ks) == sizeof(uint16_t) + QK_K/64 + QK_K/4, "wrong iq2_ks block size/padding"); typedef struct { + uint8_t scales[QK_K/64]; + uint8_t ql[QK_K/4]; +} block_iq2_kt; +static_assert(sizeof(block_iq2_kt) == QK_K/4 + QK_K/64, "wrong iq2_kt block size/padding"); + +typedef struct { + uint8_t scales[QK_K/64]; + uint8_t ql[QK_K/4]; + uint8_t qh[QK_K/8]; +} block_iq3_kt; +static_assert(sizeof(block_iq3_kt) == QK_K/4 + QK_K/8 + QK_K/64, "wrong iq3_kt block size/padding"); + +typedef struct { + uint32_t qs[QK_K/8]; +} block_iq4_kt; +static_assert(sizeof(block_iq4_kt) == QK_K/2, "wrong iq4_kt block size/padding"); + +typedef struct { 
ggml_half d; uint16_t extra; uint16_t scales_h; diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 9c8c91f4..f55715f1 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2111,6 +2111,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear + && ggml_cuda_mmvq_type_supported(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear @@ -3460,6 +3461,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ5_K: diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a04a1929..896ba0df 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -565,6 +565,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KS> { }; template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> { static constexpr int qk = QK_K; static constexpr int qr = QR4_XS; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 5afe8c74..17604f1c 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -333,6 +333,101 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f); } +inline __device__ int nearest_int(float fval) { + assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +float __device__ __forceinline__ trellis_next(uint32_t& val) { + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + uint32_t s; + const half * h = (const half *)&s; + val = ka*val + kb; + s = (val & kmask) ^ km32; + return (float)(h[0]+h[1]); +} + +template<typename dst_t> +static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + float scale = *(const float *)cx; + const block_iq2_kt * x = (const block_iq2_kt *)(cx + sizeof(float)); + const int64_t i = ii - (row*n_per_row)/QK_K; + + const int64_t tid = threadIdx.x; + const int64_t ib = tid; // 0...31 + dst_t * y = yy + ii*QK_K + 8*ib; + const uint16_t * ql = (const uint16_t *)x[i].ql; + uint32_t idx = ql[ib] + 4096; + const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f; + for (int j = 0; j < 8; ++j) { + y[j] = dl * trellis_next(idx); + } +} + +template<typename dst_t> +static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const char * cx = (const char *)vx + row * row_size; + float scale = *(const float *)cx; + const block_iq3_kt * x = (const block_iq3_kt *)(cx + sizeof(float)); + const int64_t i = ii - (row*n_per_row)/QK_K; + + const int64_t tid = threadIdx.x; + const int64_t ib = tid; // 0...31 + dst_t * y = yy + ii*QK_K + 8*ib; + const uint16_t * ql = (const uint16_t *)x[i].ql; + uint32_t idx = ql[ib] + 4096; + const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f; + uint8_t mask = 1 << (ib/4); + for (int j = 0; j < 8; ++j) { + y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f); + } +} + +template<typename dst_t> +static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) { + + int64_t ii = blockIdx.x; + int64_t row = (QK_K * ii) / n_per_row; + const float * dptr = (const float *)((const char *)vx + row * row_size); + float scale = dptr[0] * 31.75f * 1.01f; + float row_av = dptr[1]; + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + const int64_t i = ii - (row*n_per_row)/QK_K; + + constexpr int kNumGroups = 64; + + const int64_t tid = threadIdx.x; + const int64_t ib = tid; // 0...31 + dst_t * y = yy + ii*QK_K + 8*ib; + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); //Q::kNblock; + const uint8_t * qh = ql + kNumGroups; + const int ib32 = ib/4; + const int ig = ib%4; + const int jj = ib32*8 + 2*ig; + uint32_t offset = shb[ib32] & 1 ? 
4096 + 32768 : 4096; + uint32_t idx1 = ql[jj+0] + ((qh[(jj+0)%(kNumGroups/2)] << (8 - 4*((jj+0)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+0)) & 7) << 12) + offset; + uint32_t idx2 = ql[jj+1] + ((qh[(jj+1)%(kNumGroups/2)] << (8 - 4*((jj+1)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+3)) & 7) << 12) + offset; + int ls = ((shb[ib32] & 0xff) >> 1) - 64; + const float dl = scale * ls; + for (int j = 0; j < 4; ++j) { + y[j+0] = dl * trellis_next(idx1) + row_av; + y[j+4] = dl * trellis_next(idx2) + row_av; + } +} + template<typename dst_t> static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -969,6 +1064,27 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_ } template<typename dst_t> +static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int nb = k / QK_K; + dequantize_block_iq2_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row)); +} + +template<typename dst_t> +static void dequantize_row_iq3_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int nb = k / QK_K; + dequantize_block_iq3_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row)); +} + +template<typename dst_t> +static void dequantize_row_iq4_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { + const int64_t k = nrows * n_per_row; + const int nb = k / QK_K; + dequantize_block_iq4_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row)); +} + +template<typename dst_t> static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) { const int64_t k = nrows * n_per_row; const int nb = k / QK_K; @@ -1230,6 +1346,12 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_q6_K_cuda; case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda; + case GGML_TYPE_IQ2_KT: + return dequantize_row_iq2_kt_cuda; + case GGML_TYPE_IQ3_KT: + return dequantize_row_iq3_kt_cuda; + case GGML_TYPE_IQ4_KT: + return dequantize_row_iq4_kt_cuda; case GGML_TYPE_IQ2_XS: return dequantize_row_iq2_xs_cuda; case GGML_TYPE_IQ2_S: @@ -1303,6 +1425,12 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_q6_K_cuda; case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda; + case GGML_TYPE_IQ2_KT: + return dequantize_row_iq2_kt_cuda; + case GGML_TYPE_IQ3_KT: + return dequantize_row_iq3_kt_cuda; + case GGML_TYPE_IQ4_KT: + return dequantize_row_iq4_kt_cuda; case GGML_TYPE_IQ2_XS: return dequantize_row_iq2_xs_cuda; case GGML_TYPE_IQ2_S: diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu index 12738240..50e6458d 100644 --- a/ggml/src/ggml-cuda/dmmv.cu +++ b/ggml/src/ggml-cuda/dmmv.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "dmmv.cuh" #include "dequantize.cuh" #include "convert.cuh" @@ -8,6 +15,220 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); #endif +static __device__ __forceinline__ uint32_t trellis_next(uint32_t& val) { + constexpr uint32_t ka = 89226354; + 
constexpr uint32_t kb = 64248484; + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + val = ka*val + kb; + return (val & kmask) ^ km32; +} + +static __device__ __forceinline__ void trellis_accum(uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) { + const half * h = (const half *)s; + s[0] = trellis_next(val1); + s[1] = trellis_next(val1); + s[2] = trellis_next(val2); + s[3] = trellis_next(val2); +#ifdef GGML_CUDA_F16 + bdot1 = __hfma2(y[ 0], {h[0]+h[1], h[2]+h[3]}, bdot1); + bdot2 = __hfma2(y[64], {h[4]+h[5], h[6]+h[7]}, bdot2); +#else + bdot1.x += y[ 0].x * (float)(h[0] + h[1]); + bdot1.y += y[ 0].y * (float)(h[2] + h[3]); + bdot2.x += y[64].x * (float)(h[4] + h[5]); + bdot2.y += y[64].y * (float)(h[6] + h[7]); +#endif +} + +static __device__ __forceinline__ void trellis_accum_abs(uint8_t signs1, uint8_t signs2, uint8_t mask1, uint8_t mask2, + uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) { + const half * h = (const half *)s; + s[0] = trellis_next(val1); + s[1] = trellis_next(val1); + s[2] = trellis_next(val2); + s[3] = trellis_next(val2); +#ifdef GGML_CUDA_F16 + half h00 = __habs(h[0]+h[1]), h01 = __habs(h[2]+h[3]); + half h10 = __habs(h[4]+h[5]), h11 = __habs(h[6]+h[7]); + half2 h1 = {signs1 & mask1 ? -h00 : h00, signs2 & mask1 ? -h01 : h01}; + half2 h2 = {signs1 & mask2 ? -h10 : h10, signs2 & mask2 ? -h11 : h11}; + bdot1 = __hfma2(y[ 0], h1, bdot1); + bdot2 = __hfma2(y[64], h2, bdot2); +#else + bdot1.x += y[ 0].x * fabsf((float)(h[0] + h[1])) * (signs1 & mask1 ? -1 : 1); + bdot1.y += y[ 0].y * fabsf((float)(h[2] + h[3])) * (signs2 & mask1 ? -1 : 1); + bdot2.x += y[64].x * fabsf((float)(h[4] + h[5])) * (signs1 & mask2 ? -1 : 1); + bdot2.y += y[64].y * fabsf((float)(h[6] + h[7])) * (signs2 & mask2 ? 
-1 : 1); +#endif +} + +static __device__ __forceinline__ void trellis_accum(const dfloat2& dl1, const dfloat2& dl2, const dfloat2& bdot1, const dfloat2& bdot2, dfloat2& tmp) { +#ifdef GGML_CUDA_F16 + tmp = __hfma2(dl1, bdot1, tmp); + tmp = __hfma2(dl2, bdot2, tmp); +#else + tmp.x += dl1.x * bdot1.x + dl2.x * bdot2.x; + tmp.y += dl1.y * bdot1.y + dl2.y * bdot2.y; +#endif +} + +static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst, + const int ncols, int nrows, int64_t row_size) { + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const float * dptr = (const float *)((const char *)vx + row*row_size); + const float d = *dptr * 31.75f * 1.05f; + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + const int num_blocks_per_row = ncols / QK_K; + + dfloat2 tmp = {}; + + const int it = threadIdx.x/2; + const int ix = threadIdx.x%2; + + uint32_t s[4]; + + for (int i = ix; i < num_blocks_per_row; i += 2) { + const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it); + const uint16_t * ql = (const uint16_t *)x[i].ql; + const dfloat scale1 = iq4k_values[x[i].scales[it/4] & 0xf]; + const dfloat scale2 = iq4k_values[x[i].scales[it/4] >> 4]; + const dfloat2 dl1 = {scale1, scale1}; + const dfloat2 dl2 = {scale2, scale2}; + dfloat2 bdot1 = {0, 0}; + dfloat2 bdot2 = {0, 0}; + uint32_t val1 = ql[it+ 0] + 4096; + uint32_t val2 = ql[it+16] + 4096; + for (int k = 0; k < 4; ++k) { + trellis_accum(val1, val2, s, y+k, bdot1, bdot2); + } + trellis_accum(dl1, dl2, bdot1, bdot2, tmp); + } + + // sum up partial sums and write back result + tmp = warp_reduce_sum(tmp); + + if (threadIdx.x == 0) { + dst[row] = d * (float)(tmp.x + tmp.y); + } +} + +static __global__ void dequantize_mul_mat_vec_iq3_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst, + const int ncols, int nrows, int64_t row_size) { + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const float * dptr = (const float *)((const char *)vx + row*row_size); + const float d = *dptr * 31.75f * 1.015f; + const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1); + + const int num_blocks_per_row = ncols / QK_K; + + dfloat2 tmp = {}; + + const int it = threadIdx.x/2; + const int ix = threadIdx.x%2; + + uint32_t s[4]; + + uint8_t mask1 = 1 << (it/4); + uint8_t mask2 = mask1 << 4; + + for (int i = ix; i < num_blocks_per_row; i += 2) { + const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it); + const uint16_t * ql = (const uint16_t *)x[i].ql; + const uint8_t * qh = x[i].qh; + const dfloat scale1 = (x[i].scales[it/4] & 0xf); + const dfloat scale2 = (x[i].scales[it/4] >> 4); + const dfloat2 dl1 = {scale1, scale1}; + const dfloat2 dl2 = {scale2, scale2}; + dfloat2 bdot1 = {0, 0}; + dfloat2 bdot2 = {0, 0}; + uint32_t val1 = ql[it+ 0] + 4096; + uint32_t val2 = ql[it+16] + 4096; + for (int k = 0; k < 4; ++k) { + trellis_accum_abs(qh[(8*it+2*k+0)%32], qh[(8*it+2*k+1)%32], mask1, mask2, val1, val2, s, y+k, bdot1, bdot2); + } + trellis_accum(dl1, dl2, bdot1, bdot2, tmp); + } + + // sum up partial sums and write back result + tmp = warp_reduce_sum(tmp); + + if (threadIdx.x == 0) { + dst[row] = d * (float)(tmp.x + tmp.y); + } +} + +static __global__ void dequantize_mul_mat_vec_iq4_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst, + const int ncols, int nrows, int64_t row_size) { + + constexpr int kNumGroups = 64; + + const int 
row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const float * dptr = (const float *)((const char *)vx + row*row_size); + const float d = dptr[0] * 31.75f * 1.01f; + const float row_av = dptr[1]; + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + + const int num_blocks_per_row = ncols / QK_K; + + dfloat2 tmp1 = {}; + dfloat2 tmp2 = {}; + + const int it = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; // 0 or 1 + + uint32_t s[4]; + + for (int i = ix; i < num_blocks_per_row; i += 2) { + const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it); + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + const uint32_t offset1 = 4096 + ((shb[it/4+0] & 1) << 15); + const uint32_t offset2 = 4096 + ((shb[it/4+4] & 1) << 15); + const dfloat scale1 = (int)((shb[it/4+0] & 0xff) >> 1) - 64; + const dfloat scale2 = (int)((shb[it/4+4] & 0xff) >> 1) - 64; + const dfloat2 dl1 = {scale1, scale1}; + const dfloat2 dl2 = {scale2, scale2}; + const uint32_t sh1 = shb[it/4+0] >> (8 + 6*(it%4)); + const uint32_t sh2 = shb[it/4+4] >> (8 + 6*(it%4)); + dfloat2 bdot1 = {0, 0}; + dfloat2 bdot2 = {0, 0}; + uint32_t val1 = ql[2*it+ 0] + ((qh[2*it+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; + uint32_t val2 = ql[2*it+32] + ((qh[2*it+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; + uint32_t val3 = ql[2*it+ 1] + ((qh[2*it+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; + uint32_t val4 = ql[2*it+33] + ((qh[2*it+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; + for (int k = 0; k < 2; ++k) { + trellis_accum(val1, val2, s, y+k+0, bdot1, bdot2); + trellis_accum(val3, val4, s, y+k+2, bdot1, bdot2); +#ifdef GGML_CUDA_F16 + tmp2 += y[k] + y[k+2] + y[k+64] + y[k+66]; +#else + tmp2.x += y[k].x + y[k+2].x + y[k+64].x + y[k+66].x; + tmp2.y += y[k].y + y[k+2].y + y[k+64].y + y[k+66].y; +#endif + } + trellis_accum(dl1, dl2, bdot1, bdot2, tmp1); + } + + // sum up partial sums and write back result + float tmp = d * (float)(tmp1.x + tmp1.y) + row_av * (float)(tmp2.x + tmp2.y); + tmp = warp_reduce_sum(tmp); + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); @@ -554,6 +775,36 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } +static void dequantize_mul_mat_vec_iq2_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + constexpr int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(32, ny, 1); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ2_KT, ncols); + dequantize_mul_mat_vec_iq2_kt<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, row_size); +} + +static void dequantize_mul_mat_vec_iq3_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + constexpr int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(32, ny, 1); + const int64_t row_size = 
ggml_row_size(GGML_TYPE_IQ3_KT, ncols); + dequantize_mul_mat_vec_iq3_kt<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, row_size); +} + +static void dequantize_mul_mat_vec_iq4_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + constexpr int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(32, ny, 1); + const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KT, ncols); + dequantize_mul_mat_vec_iq4_kt<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, row_size); +} + static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; @@ -615,7 +866,8 @@ void ggml_cuda_op_dequantize_mul_mat_vec( bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 || + src0->type == GGML_TYPE_IQ2_KT || src0->type == GGML_TYPE_IQ3_KT || src0->type == GGML_TYPE_IQ4_KT; if (src1_convert_f16) { src1_dfloat = src1_dfloat_a.alloc(ne00); @@ -646,6 +898,15 @@ void ggml_cuda_op_dequantize_mul_mat_vec( case GGML_TYPE_Q2_K: dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; + case GGML_TYPE_IQ2_KT: + dequantize_mul_mat_vec_iq2_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_KT: + dequantize_mul_mat_vec_iq3_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_KT: + dequantize_mul_mat_vec_iq4_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q3_K: dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; @@ -679,5 +940,6 @@ bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) { src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K || src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K || + src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || src0_type == GGML_TYPE_IQ4_KT || src0_type == GGML_TYPE_F16; } diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 14fe2547..d0477835 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -618,3 +618,41 @@ void ggml_cuda_op_mul_mat_vec_q_id( GGML_UNUSED(src1_ddf_i); } + +bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ3_S: + 
return true; + default: + return false; + } +} diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 525c6bc0..d17765f1 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -14,6 +14,7 @@ void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); +bool ggml_cuda_mmvq_type_supported(ggml_type src0_type); void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 0e6aa677..220c0c99 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15421,6 +15421,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_Q6_0: break; case GGML_TYPE_IQ2_K: break; case GGML_TYPE_IQ2_KS: break; + case GGML_TYPE_IQ2_KT: break; + case GGML_TYPE_IQ3_KT: break; + case GGML_TYPE_IQ4_KT: break; case GGML_TYPE_IQ3_K: break; case GGML_TYPE_IQ4_K: break; case GGML_TYPE_IQ5_K: break; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7cbc0056..d8025a5a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1574,6 +1574,45 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 2, }, + [GGML_TYPE_IQ2_KT] = { + .type_name = "iq2_kt", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_kt), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_kt, + .from_float = quantize_row_iq2_kt, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_kt_ref, + .vec_dot = vec_dot_iq2_kt_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 4, + }, + [GGML_TYPE_IQ3_KT] = { + .type_name = "iq3_kt", + .blck_size = QK_K, + .type_size = sizeof(block_iq3_kt), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq3_kt, + .from_float = quantize_row_iq3_kt, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_kt_ref, + .vec_dot = vec_dot_iq3_kt_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 4, + }, + [GGML_TYPE_IQ4_KT] = { + .type_name = "iq4_kt", + .blck_size = QK_K, + .type_size = sizeof(block_iq4_kt), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_kt, + .from_float = quantize_row_iq4_kt, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref, + .vec_dot = vec_dot_iq4_kt_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 8, + }, [GGML_TYPE_IQ3_K] = { .type_name = "iq3_k", .blck_size = QK_K, @@ -4501,6 +4540,9 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break; case GGML_FTYPE_MOSTLY_IQ2_K_R4: wtype = GGML_TYPE_IQ2_K_R4; break; case GGML_FTYPE_MOSTLY_IQ2_KS: wtype = GGML_TYPE_IQ2_KS; break; + case GGML_FTYPE_MOSTLY_IQ2_KT: wtype = GGML_TYPE_IQ2_KT; break; + case GGML_FTYPE_MOSTLY_IQ3_KT: wtype = GGML_TYPE_IQ3_KT; break; + case GGML_FTYPE_MOSTLY_IQ4_KT: wtype = GGML_TYPE_IQ4_KT; break; case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break; case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break; case GGML_FTYPE_MOSTLY_IQ3_K_R4: wtype = GGML_TYPE_IQ3_K_R4; break; @@ -11266,6 +11308,9 
@@ static void ggml_compute_forward_add( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_K_R4: @@ -11740,6 +11785,9 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_K_R4: @@ -11911,6 +11959,9 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_K_R4: @@ -15409,6 +15460,9 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_K_R4: @@ -15820,6 +15874,9 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_K_R4: @@ -16137,6 +16194,9 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_K_R4: @@ -16771,6 +16831,9 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: case GGML_TYPE_IQ3_K_R4: @@ -23841,6 +23904,9 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_K_R4:result = quantize_iq2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_KS: result = quantize_iq2_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_KT: result = quantize_iq2_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ3_KT: result = quantize_iq3_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_KT: result = quantize_iq4_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_K_R4:result = quantize_iq3_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp new file mode 100644 index 00000000..c38dcdc6 --- /dev/null +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -0,0 +1,403 @@ +#include "iqk_gemm_ktquants.h" +#include "ggml.h" + +#ifdef IQK_IMPLEMENT + +#include "ggml-impl.h" + +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + 
+#ifdef __x86_64__ + +namespace { + +static inline uint32_t trellis_next(uint32_t& val) { + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + val = val*ka + kb; + return (val & kmask) ^ km32; +} + +static inline __m256i trellis_next8(uint32_t val) { + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t ka1 = ka*ka; + constexpr uint32_t kb1 = kb*ka+kb; + constexpr uint32_t ka2 = ka1*ka; + constexpr uint32_t kb2 = kb1*ka+kb; + constexpr uint32_t ka3 = ka2*ka; + constexpr uint32_t kb3 = kb2*ka+kb; + constexpr uint32_t ka4 = ka3*ka; + constexpr uint32_t kb4 = kb3*ka+kb; + constexpr uint32_t ka5 = ka4*ka; + constexpr uint32_t kb5 = kb4*ka+kb; + constexpr uint32_t ka6 = ka5*ka; + constexpr uint32_t kb6 = kb5*ka+kb; + constexpr uint32_t ka7 = ka6*ka; + constexpr uint32_t kb7 = kb6*ka+kb; + __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7); + __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7); + __m256i mval = _mm256_set1_epi32(val); + __m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); + return _mm256_and_si256(mres, _mm256_set1_epi32(kmask)) ^ _mm256_set1_epi32(km32); +} + +static inline float trellis_gen(uint32_t& val, uint32_t* s) { + const ggml_fp16_t * h = (const ggml_fp16_t *)s; + s[0] = trellis_next(val); + return GGML_FP16_TO_FP32(h[0]) + GGML_FP16_TO_FP32(h[1]); +} + +struct Trellis1 { + constexpr static uint32_t kmask = 0x8fff8fff; + constexpr static uint32_t km32 = 0x3b603b60; + constexpr static uint32_t ka = 89226354; + constexpr static uint32_t kb = 64248484; + constexpr static uint32_t ka1 = ka*ka; + constexpr static uint32_t kb1 = kb*ka+kb; + constexpr static uint32_t ka2 = ka1*ka; + constexpr static uint32_t kb2 = kb1*ka+kb; + constexpr static uint32_t ka3 = ka2*ka; + constexpr static uint32_t kb3 = kb2*ka+kb; + constexpr static uint32_t ka4 = ka3*ka; + constexpr static uint32_t kb4 = kb3*ka+kb; + constexpr static uint32_t ka5 = ka4*ka; + constexpr static uint32_t kb5 = kb4*ka+kb; + constexpr static uint32_t ka6 = ka5*ka; + constexpr static uint32_t kb6 = kb5*ka+kb; + constexpr static uint32_t ka7 = ka6*ka; + constexpr static uint32_t kb7 = kb6*ka+kb; + const __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka4, ka5, ka6, ka7); + const __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb4, kb5, kb6, kb7); + const __m256i mask1 = _mm256_set1_epi32(kmask); + const __m256i mask2 = _mm256_set1_epi32(km32); + + inline __m256i next8(uint32_t val) const { + auto mval = _mm256_set1_epi32(val); + auto mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); + return _mm256_and_si256(mres, mask1) ^ mask2; + } +}; + +static inline __m256 trellis_gen8(__m256i i8) { + // split upper and lower bits of each 32-bit lane into two 8xfloat16 `hlo`, `hhi` + __m256i low_16_bits_mask = _mm256_set1_epi32(0x0000FFFF); + __m256i lower_halves_lanes32 = _mm256_and_si256(i8, low_16_bits_mask); + __m256i upper_halves_lanes32 = _mm256_srli_epi32(i8, 16); + // 00L0, 00L1, 00L2, 00L3, 00H0, 00H1, 00H2, 00H3, 00L4, 00L5, 00L6, 00L7, 00H4, 00H5, 00H6, 00H7 + auto iv = _mm256_packus_epi32(lower_halves_lanes32, upper_halves_lanes32); + // 00L0, 00L1, 00L2, 00L3, 00L4, 00L5, 00L6, 00L7, 00H0, 00H1, 00H2, 00H3, 00H4, 00H5, 00H6, 00H7 + iv = _mm256_permute4x64_epi64(iv, 0xd8); + auto fv1 = _mm256_cvtph_ps(_mm256_extracti128_si256(iv, 
0)); + auto fv2 = _mm256_cvtph_ps(_mm256_extracti128_si256(iv, 1)); + return _mm256_add_ps(fv1, fv2); +} + +struct Trellis2 { + constexpr static uint32_t kmask = 0x8fff8fff; + constexpr static uint32_t km32 = 0x3b603b60; + constexpr static uint32_t ka = 89226354; + constexpr static uint32_t kb = 64248484; + constexpr static uint32_t ka1 = ka*ka; + constexpr static uint32_t kb1 = kb*ka+kb; + constexpr static uint32_t ka2 = ka1*ka; + constexpr static uint32_t kb2 = kb1*ka+kb; + constexpr static uint32_t ka3 = ka2*ka; + constexpr static uint32_t kb3 = kb2*ka+kb; + __m256i mka = _mm256_setr_epi32(ka, ka1, ka2, ka3, ka, ka1, ka2, ka3); + __m256i mkb = _mm256_setr_epi32(kb, kb1, kb2, kb3, kb, kb1, kb2, kb3); + const __m256i mask1 = _mm256_set1_epi32(kmask); + const __m256i mask2 = _mm256_set1_epi32(km32); + + inline __m256i next8(uint32_t val1, uint32_t val2) { + __m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2); + __m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); + return _mm256_and_si256(mres, _mm256_set1_epi32(kmask)) ^ _mm256_set1_epi32(km32); + } +}; + +template <int nrc_y> +static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis1 trellis; + + auto shifts = _mm_set_epi32(0, 0, 4, 0); + auto values = _mm_loadu_si128((const __m128i *)iq4k_values); + + union { __m256 vec; float val[8]; } s_helper; + + constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y; + __m256 accd[k_acc]; + const float * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + const float d = *dptr * 31.75f * 1.05f; + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + s8 = _mm_shuffle_epi8(values, s8); + auto s32 = _mm256_cvtepi8_epi32(s8); + s_helper.vec = _mm256_cvtepi32_ps(s32); + for (int ib = 0; ib < QK_K/64; ++ib) { + auto scale1 = _mm256_set1_ps(s_helper.val[2*ib+0]); + auto scale2 = _mm256_set1_ps(s_helper.val[2*ib+1]); + for (int j = 0; j < 4; ++j) { + auto xval1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(ql[8*ib+j+0]+4096))); + auto xval2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(ql[8*ib+j+4]+4096))); + if constexpr (nrc_y == 1) { + accd[0] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 64*ib + 8*j + 0), xval1, accd[0]); + accd[1] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 64*ib + 8*j + 32), xval2, accd[1]); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 64*ib + 8*j + 0), xval1, accd[iy]); + accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 64*ib + 8*j + 32), xval2, accd[iy]); + } + } + } + } + } + + if constexpr (nrc_y == 1) { + __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_add_ps(accd[0], accd[1])); + info.store(ix, 0, hsum_float_8(res)); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]); + info.store(ix, iy, hsum_float_8(res)); + } + } + } +} + +static inline __m256 abs_ps(__m256 vals) { + // Clear sign-bit of all the 32-bit floats in vals + __m256 sign_bit = 
_mm256_set1_ps(-0.0f); + return _mm256_andnot_ps(sign_bit, vals); +} + +// Negates 32-bit float lanes of an 8x32-bit vector +// based on 8x8-bit condition var. For float lane i, if byte i of +// `condition` is nonzero, the float will be negated. +static inline __m256 conditional_negate_ps(__m256 vals, uint64_t condition_mask_u64) { + __m128i condition_bytes = _mm_set_epi64x(0, condition_mask_u64); + // Make `should_negate_byte_mask` where byte i == 0xFF if byte i in condition_bytes is zero, + // else 0x00 (upper bytes are meaningless) + __m128i zeros = _mm_setzero_si128(); + __m128i is_zero_byte_mask = _mm_cmpeq_epi8(condition_bytes, zeros); + __m128i should_negate_byte_mask = _mm_cmpeq_epi8(is_zero_byte_mask, zeros); + // Widen lower 8x8 bits of `should_negate_byte_mask` to 8x32 bits by padding zeros + // expanded_mask_epi32[j] will be 0x000000FF if vals[j] should be negated, zero otherwise + __m256i expanded_mask_epi32 = _mm256_cvtepu8_epi32(should_negate_byte_mask); + // Same as above but with all 32 bits of lane j set if vals[j] should be negated (use to make XOR mask) + __m256i full_dword_negate_mask = _mm256_cmpgt_epi32(expanded_mask_epi32, _mm256_setzero_si256()); + // Negate via XOR on sign bits of each 32-bit float + __m256i sign_bit_pattern = _mm256_set1_epi32(0x80000000); // MSB set for a 32-bit value + __m256i xor_mask_epi32 = _mm256_and_si256(full_dword_negate_mask, sign_bit_pattern); + __m256 xor_mask_ps = _mm256_castsi256_ps(xor_mask_epi32); + return _mm256_xor_ps(vals, xor_mask_ps); +} + +template <int nrc_y> +static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis1 trellis; + + __m256 accd[nrc_y]; + const float * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + const float d = *dptr * 31.75f * 1.015f; + const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 1); + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + const uint8_t * qh = x[i].qh; + for (int j = 0; j < 128; j+=8) { + uint64_t mask1 = 0x0101010101010101 << (j/32); + uint64_t mask2 = mask1 << 4; + uint32_t val1 = ql[j/8] + 4096; + uint32_t val2 = ql[j/8+16] + 4096; + const uint64_t signs = *((const uint64_t *)(qh + (j%32))); + const float x_scale1 = (x[i].scales[j/32] & 0xf); + const float x_scale2 = (x[i].scales[j/32] >> 4); + const __m256 x_val1 = abs_ps(trellis_gen8(trellis.next8(val1))); + const __m256 x_val2 = abs_ps(trellis_gen8(trellis.next8(val2))); + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_fmadd_ps( + conditional_negate_ps( + _mm256_load_ps(y[iy] + i*QK_K+j), signs & mask1 + ), + _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1), + accd[iy] + ); + accd[iy] = _mm256_fmadd_ps( + conditional_negate_ps( + _mm256_load_ps(y[iy] + i*QK_K+j+128), signs & mask2 + ), + _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2), + accd[iy] + ); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]); + info.store(ix, iy, hsum_float_8(res)); + } + } +} + +template <int nrc_y> +static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + constexpr int kNumGroups = 64; + + Trellis2 trellis; + + __m256 accd[nrc_y]; + 
__m256 accd2[nrc_y]; + const float * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + const float d = dptr[0] * 31.75f * 1.01f; + const float row_av = dptr[1]; + const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); + + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_setzero_ps(); + accd2[iy] = _mm256_setzero_ps(); + } + + for (int i = 0; i < nb; ++i) { + const uint32_t * shb = x[i].qs; + const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * qh = ql + kNumGroups; + for (int j = 0; j < 128; j+=8) { + const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15); + const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15); + const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64; + const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64; + const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4)); + const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4)); + uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; + uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; + uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; + uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; + const __m256 x_val1 = trellis_gen8(trellis.next8(val1, val3)); + const __m256 x_val2 = trellis_gen8(trellis.next8(val2, val4)); + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = _mm256_fmadd_ps( + _mm256_load_ps(y[iy] + i*QK_K+j), + _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1), + accd[iy] + ); + accd[iy] = _mm256_fmadd_ps( + _mm256_load_ps(y[iy] + i*QK_K+j+128), + _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2), + accd[iy] + ); + accd2[iy] = _mm256_add_ps( + _mm256_load_ps(y[iy] + i*QK_K+j), + accd2[iy] + ); + accd2[iy] = _mm256_add_ps( + _mm256_load_ps(y[iy] + i*QK_K+j+128), + accd2[iy] + ); + } + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]); + __m256 res2 = _mm256_mul_ps(_mm256_set1_ps(row_av), accd2[iy]); + info.store(ix, iy, hsum_float_8(res) + hsum_float_8(res2)); + } + } +} + +} // namespace + +bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) { + + if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F32) { + return false; + } + + func16 = nullptr; + + switch (typeA) { + case GGML_TYPE_IQ2_KT: + assert (ne00 % QK_K == 0); + kernels[0] = mul_mat_iq2_kt_F32_T<1>; + kernels[1] = mul_mat_iq2_kt_F32_T<2>; + kernels[2] = mul_mat_iq2_kt_F32_T<3>; + kernels[3] = mul_mat_iq2_kt_F32_T<4>; + kernels[4] = mul_mat_iq2_kt_F32_T<5>; + kernels[5] = mul_mat_iq2_kt_F32_T<6>; + kernels[6] = mul_mat_iq2_kt_F32_T<7>; + kernels[7] = mul_mat_iq2_kt_F32_T<8>; + break; + case GGML_TYPE_IQ3_KT: + assert (ne00 % QK_K == 0); + kernels[0] = mul_mat_iq3_kt_F32_T<1>; + kernels[1] = mul_mat_iq3_kt_F32_T<2>; + kernels[2] = mul_mat_iq3_kt_F32_T<3>; + kernels[3] = mul_mat_iq3_kt_F32_T<4>; + kernels[4] = mul_mat_iq3_kt_F32_T<5>; + kernels[5] = mul_mat_iq3_kt_F32_T<6>; + kernels[6] = mul_mat_iq3_kt_F32_T<7>; + kernels[7] = mul_mat_iq3_kt_F32_T<8>; + break; + case GGML_TYPE_IQ4_KT: + assert (ne00 % QK_K == 0); + kernels[0] = mul_mat_iq4_kt_F32_T<1>; + kernels[1] = mul_mat_iq4_kt_F32_T<2>; + kernels[2] = mul_mat_iq4_kt_F32_T<3>; + kernels[3] = mul_mat_iq4_kt_F32_T<4>; + kernels[4] = mul_mat_iq4_kt_F32_T<5>; + 
kernels[5] = mul_mat_iq4_kt_F32_T<6>; + kernels[6] = mul_mat_iq4_kt_F32_T<7>; + kernels[7] = mul_mat_iq4_kt_F32_T<8>; + break; + default: + return false; + } + + return true; + +} + +#else // !__x86_64__ + +bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) { + return false; +} + +#endif + +#endif
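The AVX2 helpers above (trellis_next8, trellis_gen8) are an unrolled form of the scalar generator: one multiply-add LCG step per output value, then a mask/xor that reshapes the state into two fp16 bit patterns whose sum is the codebook entry. A minimal scalar sketch of how one group of 8 values is produced, assuming the GGML_FP16_TO_FP32 helper from ggml-impl.h (which this file already includes); trellis_group8_ref is a hypothetical name used only for illustration, not part of the patch:

// Mirrors trellis_next()/trellis_gen(): callers pass ql[...] + 4096 as the seed
// (iq4_kt adds further high bits from qh and the per-block header word).
static void trellis_group8_ref(uint32_t seed, float * out) {
    constexpr uint32_t ka = 89226354, kb = 64248484;
    constexpr uint32_t kmask = 0x8fff8fff, km32 = 0x3b603b60;
    uint32_t val = seed;
    for (int k = 0; k < 8; ++k) {
        val = ka*val + kb;                   // one LCG step per generated value
        uint32_t s = (val & kmask) ^ km32;   // two fp16 bit patterns packed in one uint32
        out[k] = GGML_FP16_TO_FP32((ggml_fp16_t)(s & 0xffff))
               + GGML_FP16_TO_FP32((ggml_fp16_t)(s >> 16));
    }
}

The block scales and the row-level d (stored ahead of the row data and scaled by 31.75 plus a small per-type correction factor on load) are applied by the callers, exactly as in the kernels above.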
\ No newline at end of file diff --git a/ggml/src/iqk/iqk_gemm_ktquants.h b/ggml/src/iqk/iqk_gemm_ktquants.h new file mode 100644 index 00000000..b1e84d63 --- /dev/null +++ b/ggml/src/iqk/iqk_gemm_ktquants.h @@ -0,0 +1,11 @@ +#pragma once + +#include "iqk_common.h" + +#ifdef IQK_IMPLEMENT + +#include <array> + +bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16); + +#endif diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index abf14ed0..43be0885 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -22,6 +22,7 @@ #include "iqk_flash_impl.h" #include "iqk_gemm_floats.h" #include "iqk_gemm_kquants.h" +#include "iqk_gemm_ktquants.h" #include "iqk_gemm_iquants.h" #include "iqk_gemm_iqk_quants.h" #include "iqk_gemm_1bit.h" @@ -541,6 +542,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ5_KS_R4: return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs, mm.func16); + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: + return ggml_type(typeB) == GGML_TYPE_F32 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, mm.funcs, mm.func16) : false; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -921,4 +926,4 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*n return false; } -#endif +#endif
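One layout detail worth noting before the quantization code that follows: an iq4_kt row stores two floats ahead of its blocks, the scale d and a per-row mean row_av, and both the CUDA DMMV kernel and the AVX2 mul_mat_iq4_kt kernel above fold that mean back in at the end of the dot product (tmp1/tmp2 in the CUDA kernel, accd/accd2 in the AVX2 kernel). A scalar sketch of the per-row result they accumulate; illustrative only, with decoded standing in for the trellis values already multiplied by their block scales, which the real kernels never materialize:

// dst[row] = d * <decoded_row, y> + row_av * sum(y)
static float iq4_kt_row_dot_ref(const float * decoded, const float * y, int n,
                                float d, float row_av) {
    float dot = 0.0f, ysum = 0.0f;
    for (int j = 0; j < n; ++j) {
        dot  += decoded[j]*y[j];
        ysum += y[j];
    }
    return d*dot + row_av*ysum;
}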
\ No newline at end of file diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 93aa2180..c1f7a8e4 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -22,6 +22,8 @@ #include <algorithm> #include <cstring> #include <mutex> +#include <random> +#include <memory> #include <thread> #include <atomic> #include <unordered_map> @@ -7408,3 +7410,1387 @@ void dequantize_row_ms_i2s(const void * vx, float * y, int64_t k) { } } +namespace { +#ifdef __AVX2__ +__m128 hsum_float_4x4(__m128 * accm) { + accm[0] = _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[2]), _mm_unpackhi_ps(accm[0], accm[2])); + accm[1] = _mm_add_ps(_mm_unpacklo_ps(accm[1], accm[3]), _mm_unpackhi_ps(accm[1], accm[3])); + return _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[1]), _mm_unpackhi_ps(accm[0], accm[1])); +} +__m256 hsum_float_8x8(__m256 * accm) { + for (int i = 0; i < 4; ++i) { + accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)), + _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); + } + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} +__m256 hsum_float_4x8(__m256 * accm) { + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} +#endif +template <int block_size, int group_size, int num_bits, bool is_abs = false> +class QuantizerIQKT { + static_assert(group_size == 8 || group_size == 4); + static_assert(block_size >= 8 && block_size%8 == 0); +public: + constexpr static int kSuperBlockSize = QK_K; + constexpr static int kBlockSize = block_size; + constexpr static int kGroupSize = group_size; + constexpr static int kNg = kBlockSize/kGroupSize; + constexpr static int kNblock = kSuperBlockSize/kBlockSize; + constexpr static int kNumVal = 1 << num_bits; // i.e, 16 bits per group of 8 + constexpr static float kScale = 31.75f; + constexpr static bool kVerbose = false; + + QuantizerIQKT(int num_clusters, int num_neighbours, int offset = 4096); + const float * values() const { return m_values.data(); } + + inline void find_best_match(float d, const float * xb, const float * weight, int * best_idx) const; + inline std::pair<float, float> find_best_scale(const float * xb, const float * weight, const int * best_idx) const; + inline float find_best_inverse_scale(const float * xb, const float * weight, const int * best_idx) const; + + static inline void set_values(uint32_t i, float * result, float scale, int offset = 4096) { + constexpr uint32_t ka = 89226354; + constexpr uint32_t kb = 64248484; + constexpr uint32_t kmask = 0x8fff8fff; + constexpr uint32_t km32 = 0x3b603b60; + uint32_t x = i + offset; + for (int k = 0; k < kGroupSize; ++k) { + x = ka*x + kb; + uint32_t s = (x & kmask) ^ km32; + float val = GGML_FP16_TO_FP32(s & 65535) + GGML_FP16_TO_FP32(s >> 16); + if constexpr (is_abs) result[k] = scale*std::abs(val); + else result[k] = scale*val; + } + } + + static inline int bin4(float x) { + if constexpr (is_abs) { + return x < 16.f ? 0 : x < 32.f ? 1 : x < 64.f ? 2 : 3; + } else { + return x < -24.f ? 0 : x < 0.0f ? 1 : x < 24.f ? 2 : 3; + } + } + static inline int bin5(float x) { + if constexpr (is_abs) { + return x < 11.2f ? 
0 : x < 24.f ? 1 : x < 39.f ? 2 : x < 58.f ? 3 : 4; + } else { + return x < -48.f ? 0 : x < -16.f ? 1 : x < 16.f ? 2 : x < 48.f ? 3 : 4; + } + } + inline int bin3(int idim, float x) const { return x < m_mid[2*idim+0] ? 0 : x < m_mid[2*idim+1] ? 1 : 2; } + + static inline void set_weights(float sigma2_scale, int nblock, const float * x, const float * imatrix, float * row_weights) { + for (int ibl = 0; ibl < nblock; ++ibl) { + + const float * xbl = x + ibl*kSuperBlockSize; + float * wbl = row_weights + ibl*kSuperBlockSize; + + float sumx2 = 0; + for (int j = 0; j < kSuperBlockSize; ++j) sumx2 += xbl[j]*xbl[j]; + const float sigma2 = sigma2_scale*sumx2/kSuperBlockSize; + + if (imatrix) { + const float * qw = imatrix + ibl*kSuperBlockSize; + for (int j = 0; j < kSuperBlockSize; ++j) wbl[j] = qw[j] * sqrtf(sigma2 + xbl[j]*xbl[j]); + } else { + for (int j = 0; j < kSuperBlockSize; ++j) wbl[j] = 0.25f*sigma2 + xbl[j]*xbl[j]; + } + } + } +private: + static std::vector<float> cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid); + static std::vector<std::vector<int>> finalize_clusters(int num_neighbours, const std::vector<float>& points, const std::vector<float>& clusters, + std::vector<std::vector<float>>& c_values); + std::vector<float> m_values; + std::vector<float> m_clusters; + std::vector<std::vector<int>> m_in_cluster; + std::vector<std::vector<float>> m_c_values; + float m_mid[4*kGroupSize]; +}; + +template <int block_size, int group_size, int num_bits, bool is_abs> +QuantizerIQKT<block_size, group_size, num_bits, is_abs>::QuantizerIQKT(int num_clusters, int num_neighbours, int offset) { + m_values.resize(kNumVal*kGroupSize); + float * data = m_values.data(); + for (int i = 0; i < kNumVal; ++i) { + set_values(i, data, kScale, offset); + data += kGroupSize; + } + // Make 128 clusters. + // Note: we get a slightly better result by using 64 clusters + // at the expense of almost doubling the quantization time. + m_clusters = cluster_points(m_values, num_clusters, 200, m_mid); + GGML_ASSERT(!m_clusters.empty()); + m_in_cluster = finalize_clusters(num_neighbours, m_values, m_clusters, m_c_values); +} + +template <int block_size, int group_size, int num_bits, bool is_abs> +std::pair<float, float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_scale( + const float * xb, const float * weight, const int * best_idx) const { + float sumqx = 0, sumq2 = 0; +#ifdef __AVX2__ + auto vqx = _mm256_setzero_ps(); + auto vq2 = _mm256_setzero_ps(); + for (int l = 0; l < kBlockSize; l += 8) { + auto vx = _mm256_loadu_ps(xb+l); + auto vw = _mm256_loadu_ps(weight+l); + auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize]) : + _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+1]), + _mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+0])); + auto vqw = _mm256_mul_ps(vq, vw); + vqx = _mm256_fmadd_ps(vqw, vx, vqx); + vq2 = _mm256_fmadd_ps(vqw, vq, vq2); + } + sumqx = hsum_float_8(vqx); + sumq2 = hsum_float_8(vq2); +#else + for (int l = 0; l < kNg; ++l) { + auto xl = xb + kGroupSize*l; + auto wl = weight + kGroupSize*l; + auto ql = m_values.data() + kGroupSize*best_idx[l]; + for (int k = 0; k < kGroupSize; ++k) { + sumqx += wl[k]*ql[k]*xl[k]; + sumq2 += wl[k]*ql[k]*ql[k]; + } + } +#endif + return sumq2 > 0 ? 
std::make_pair(sumqx/sumq2, sumqx*sumqx/sumq2) : std::make_pair(0.f, 0.f); +} + +template <int block_size, int group_size, int num_bits, bool is_abs> +float QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_inverse_scale( + const float * xb, const float * weight, const int * best_idx) const { + float sumqx = 0, sumx2 = 0; +#ifdef __AVX2__ + auto vqx = _mm256_setzero_ps(); + auto vx2 = _mm256_setzero_ps(); + for (int l = 0; l < kBlockSize; l += 8) { + auto vx = _mm256_loadu_ps(xb+l); + auto vw = _mm256_loadu_ps(weight+l); + auto vq = kGroupSize == 8 ? _mm256_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize]) : + _mm256_set_m128(_mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+1]), + _mm_loadu_ps(m_values.data() + kGroupSize*best_idx[l/kGroupSize+0])); + auto vxw = _mm256_mul_ps(vx, vw); + vx2 = _mm256_fmadd_ps(vxw, vx, vx2); + vqx = _mm256_fmadd_ps(vxw, vq, vqx); + } + sumqx = hsum_float_8(vqx); + sumx2 = hsum_float_8(vx2); +#else + for (int l = 0; l < kNg; ++l) { + auto xl = xb + kGroupSize*l; + auto wl = weight + kGroupSize*l; + auto ql = m_values.data() + kGroupSize*best_idx[l]; + for (int k = 0; k < kGroupSize; ++k) { + sumqx += wl[k]*ql[k]*xl[k]; + sumx2 += wl[k]*xl[k]*xl[k]; + } + } +#endif + return sumx2 > 0 ? sumqx/sumx2 : 0.f; +} + +template <int block_size, int group_size, int num_bits, bool is_abs> +void QuantizerIQKT<block_size, group_size, num_bits, is_abs>::find_best_match(float d, const float * xb, const float * weight, int * best_idx) const { + if (!d) { + std::memset(best_idx, 0, kNg*sizeof(int)); + return; + } + int ncluster = m_clusters.size()/kGroupSize; + float id = 1/d; +#ifdef __AVX2__ + if constexpr (kGroupSize == 8) { + __m256 sqx[8]; + const __m256i add_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + float sx[8]; + int index[8]; + auto vid = _mm256_set1_ps(id); + auto add8 = _mm256_set1_epi32(8); + for (int l = 0; l < kNg; ++l) { + auto xl = xb + 8*l; + auto wl = weight + 8*l; + auto vx = _mm256_mul_ps(vid, _mm256_loadu_ps(xl)); + auto vw = _mm256_loadu_ps(wl); + int jbest = -1; + if (kGroupSize == 8 && (ncluster == 256 || ncluster == 6561)) { + _mm256_store_ps(sx, vx); + uint16_t u = 0; + if (ncluster == 256) { + for (int j = 0; j < 8; ++j) if (sx[j] > m_mid[j]) u |= (1 << j); + } else { + int s = 1; + for (int j = 0; j < 8; ++j) { u += s*bin3(j, sx[j]); s *= 3; } + } + jbest = u; + } else { + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + float best = INFINITY; + auto idx = add_idx; + for (int j = 0; j < ncluster; j += 8) { + for (int i = 0; i < 8; ++i) { + auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+i)); + auto vdiff = _mm256_sub_ps(vq, vx); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + } + auto score = hsum_float_8x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + idx = _mm256_add_epi32(idx, add8); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + } + auto& points = m_in_cluster[jbest]; + auto& values = points.empty() ? 
m_values : m_c_values[jbest]; + int npoint = values.size()/kGroupSize; + GGML_ASSERT(npoint > 0 && npoint%8 == 0); + int jbest_cluster = jbest; + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + auto best = INFINITY; jbest = -1; + auto idx = add_idx; + for (int j = 0; j < npoint; j += 8) { + for (int i = 0; i < 8; ++i) { + auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+i)); + auto vdiff = _mm256_sub_ps(vq, vx); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + } + auto score = hsum_float_8x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + idx = _mm256_add_epi32(idx, add8); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + if (jbest < 0) { + fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size())); + GGML_ASSERT(false); + } + best_idx[l] = points.empty() ? jbest : points[jbest]; + } + } else { + __m256 sqx[4]; + const __m256i add_idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + const __m256 sign_bit = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff)); + float sx[8]; + int index[8]; + auto vid_p = _mm256_set1_ps(id); + auto add8 = _mm256_set1_epi32(8); + for (int l = 0; l < kNg; ++l) { + auto xl = xb + 4*l; + auto wl = weight + 4*l; + auto vx4 = _mm_loadu_ps(xl); + auto vx = _mm256_mul_ps(vid_p, _mm256_set_m128(vx4, vx4)); + auto vw4 = _mm_loadu_ps(wl); + auto vw = _mm256_set_m128(vw4, vw4); + int jbest = -1; + if (ncluster == 256 || ncluster == 625) { + _mm256_storeu_ps(sx, vx); + uint16_t u = 0; + if (ncluster == 256) { + for (int k = 0; k < 4; ++k) u |= (bin4(sx[k]) << 2*k); + } else { + int l = 1; + for (int k = 0; k < 4; ++k) { u += bin5(sx[k])*l; l *= 5; } + } + jbest = u; + } else { + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + float best = INFINITY; + auto idx = add_idx; + for (int j = 0; j < ncluster; j += 8) { + for (int i = 0; i < 4; ++i) { + auto vq = _mm256_loadu_ps(m_clusters.data() + kGroupSize*(j+2*i)); + auto vdiff = _mm256_sub_ps(vq, vx); + vdiff = _mm256_and_ps(sign_bit, vdiff); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, _mm256_mul_ps(vdiff, vdiff))); + } + auto score = hsum_float_4x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + idx = _mm256_add_epi32(idx, add8); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + } + auto& points = m_in_cluster[jbest]; + auto& values = m_c_values[jbest]; + GGML_ASSERT(!points.empty() && points.size()%8 == 0); + int jbest_cluster = jbest; + auto vbest = _mm256_set1_ps(INFINITY); + auto best_index = _mm256_set1_epi32(-1); + float best = INFINITY; jbest = -1; + auto idx = add_idx; + for (int j = 0; j < int(points.size()); j += 8) { + for (int i = 0; i < 4; ++i) { + auto vq = _mm256_loadu_ps(values.data() + kGroupSize*(j+2*i)); + auto vdiff = _mm256_sub_ps(vq, vx); + sqx[i] = _mm256_mul_ps(vw, _mm256_mul_ps(vdiff, vdiff)); + } + auto score = 
hsum_float_4x8(sqx); + auto mask = _mm256_cmp_ps(score, vbest, _CMP_LT_OQ); + best_index = _mm256_or_si256(_mm256_and_si256(_mm256_castps_si256(mask), idx), + _mm256_andnot_si256(_mm256_castps_si256(mask), best_index)); + vbest = _mm256_min_ps(vbest, score); + idx = _mm256_add_epi32(idx, add8); + } + _mm256_store_ps(sx, vbest); + _mm256_store_si256((__m256i *)index, best_index); + for (int i = 0; i < 8; ++i) { + if (sx[i] < best) { best = sx[i]; jbest = index[i]; } + } + if (jbest < 0) { + fprintf(stderr, "Oops: jbest = %d for cluster %d with %d points\n", jbest, jbest_cluster, int(points.size())); + GGML_ASSERT(false); + } + best_idx[l] = points[jbest]; + } + } +#else + // TODO + std::memset(best_idx, 0, kNg*sizeof(int)); +#endif +} + +template <int block_size, int group_size, int num_bits, bool is_abs> +std::vector<std::vector<int>> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::finalize_clusters(int num_neighbours, + const std::vector<float>& values, const std::vector<float>& clusters, std::vector<std::vector<float>>& c_values) { + int ncluster = clusters.size()/kGroupSize; + std::vector<std::vector<int>> p_in_cluster(ncluster); + std::vector<int> which_cluster(num_neighbours*kNumVal); + std::vector<int> ibest(num_neighbours); + std::vector<float> best(num_neighbours); + for (int ip = 0; ip < kNumVal; ++ip) { + auto vp = values.data() + ip*kGroupSize; + for (int j = 0; j < num_neighbours; ++j) { + best[j] = INFINITY; ibest[j] = -1; + } + for (int ic = 0; ic < ncluster; ++ic) { + auto vc = clusters.data() + ic*kGroupSize; + float dist2 = 0; + for (int k = 0; k < kGroupSize; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + for (int j = 0; j < num_neighbours; ++j) { + if (dist2 < best[j]) { + for (int k = num_neighbours-1; k > j; --k) { + best[k] = best[k-1]; ibest[k] = ibest[k-1]; + } + best[j] = dist2; ibest[j] = ic; + break; + } + } + } + for (int j = 0; j < num_neighbours; ++j) { + if (ibest[j] < 0) { + printf("Oops: ibest[%d] = %d\n", j, ibest[j]); + } + GGML_ASSERT(ibest[j] >= 0); + p_in_cluster[ibest[j]].push_back(ip); + } + std::memcpy(which_cluster.data() + num_neighbours*ip, ibest.data(), num_neighbours*sizeof(int)); + } + std::vector<std::pair<float, int>> extra; + extra.reserve(kNumVal); + for (int ic = 0; ic < ncluster; ++ic) { + auto& points = p_in_cluster[ic]; + if (!points.empty() && points.size()%8 == 0) continue; + extra.clear(); + auto vc = clusters.data() + ic*kGroupSize; + for (int ip = 0; ip < kNumVal; ++ip) { + bool can_add = true; + for (int j = 0; j < num_neighbours; ++j) { + if (which_cluster[num_neighbours*ip+j] == ic) { can_add = false; break; } + } + if (!can_add) continue; + auto vp = values.data() + ip*kGroupSize; + float dist2 = 0; + for (int k = 0; k < kGroupSize; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + extra.push_back(std::make_pair(dist2, ip)); + } + std::sort(extra.begin(), extra.end()); + int nadd = 8*((points.size()+7)/8) - points.size(); + for (int i = 0; i < nadd; ++i) points.push_back(extra[i].second); + GGML_ASSERT(points.size()%8 == 0); + } + auto min = p_in_cluster.front().size(), max = p_in_cluster.front().size(); + for (auto& points : p_in_cluster) { + min = std::min(min, points.size()); + max = std::max(max, points.size()); + } + c_values.resize(p_in_cluster.size()); + for (int i = 0; i < int(p_in_cluster.size()); ++i) { + auto& points = p_in_cluster[i]; + c_values[i].resize(points.size()*kGroupSize); + auto ptr = c_values[i].data(); + for (auto j : points) { + std::memcpy(ptr, values.data() + j*kGroupSize, 
kGroupSize*sizeof(float)); + ptr += kGroupSize; + } + } + + if (kVerbose) { + printf("%s: prepared %d clusters\n", __func__, ncluster); + printf(" min number of points in a cluster: %d\n", int(min)); + printf(" max number of points in a cluster: %d\n", int(max)); + } + return p_in_cluster; +} + +template <int block_size, int group_size, int num_bits, bool is_abs> +std::vector<float> QuantizerIQKT<block_size, group_size, num_bits, is_abs>::cluster_points(const std::vector<float>& points, int ncluster, int niter, float * mid) { + constexpr int ndim = kGroupSize; + GGML_ASSERT(points.size() % ndim == 0); + int npoint = points.size() / ndim; + GGML_ASSERT(npoint >= 2*ncluster); + std::vector<std::pair<float, float>> range(ndim, std::make_pair(INFINITY, -INFINITY)); + double Fo = 0; + for (int i = 0; i < npoint; ++i) { + auto v = points.data() + i*ndim; + for (int k = 0; k < ndim; ++k) { + Fo += v[k]*v[k]; + range[k].first = std::min(range[k].first, v[k]); + range[k].second = std::max(range[k].second, v[k]); + } + } + if (kVerbose) printf("%s (ndim = %d, npoint = %d): Fo = %g\n", __func__, ndim, npoint, Fo/points.size()); + if constexpr (is_abs) { + std::vector<int> P(npoint); + for (int idim = 0; idim < ndim; ++idim) { + for (int ip = 0; ip < npoint; ++ip) P[ip] = points[ip*ndim+idim]; + std::sort(P.begin(), P.end()); + if (ndim == 8 && ncluster == 6561) { + mid[2*idim + 0] = P[npoint/3]; + mid[2*idim + 1] = P[2*npoint/3]; + } else { + mid[idim] = npoint%2 == 0 ? 0.5f*(P[npoint/2] + P[npoint/2-1]) : P[npoint/2]; + if (kVerbose) printf("%s: mid[%d] = %g\n", __func__, idim, mid[idim]); + } + } + } else { + for (int k = 0; k < ndim; ++k) mid[k] = 0.5f*(range[k].first + range[k].second); + } + std::vector<float> sump(ncluster*ndim); + std::vector<int> counts(ncluster); + std::vector<float> result(ncluster*ndim); + if (ndim == 8 && (ncluster == 256 || ncluster == 6561)) { + std::memset(sump.data(), 0, sump.size()*sizeof(float)); + std::memset(counts.data(), 0, counts.size()*sizeof(int)); + for (int ip = 0; ip < npoint; ++ip) { + auto vp = points.data() + ndim*ip; + uint16_t u = 0; + if (ncluster == 256) { + for (int k = 0; k < ndim; ++k) if (vp[k] > mid[k]) u |= (1 << k); + } else { + int s = 1; + for (int k = 0; k < ndim; ++k) { + int bin = vp[k] < mid[2*k+0] ? 0 : vp[k] < mid[2*k+1] ? 1 : 2; + u += s*bin; s *= 3; + } + } + ++counts[u]; + for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k]; + } + for (int ic = 0; ic < ncluster; ++ic) { + if (!counts[ic]) { + printf("%s: Oops. 
Cluster %d has no points\n", __func__, ic); + GGML_ABORT("fatal error"); + } + for (int k = 0; k < ndim; ++k) result[ic*ndim + k] = sump[ic*ndim + k]/counts[ic]; + } + return result; + } + else if (ndim == 4 && (ncluster == 256 || ncluster == 625)) { + std::memset(sump.data(), 0, sump.size()*sizeof(float)); + std::memset(counts.data(), 0, counts.size()*sizeof(int)); + for (int ip = 0; ip < npoint; ++ip) { + auto vp = points.data() + ndim*ip; + uint16_t u = 0; + if (ncluster == 256) { + for (int k = 0; k < ndim; ++k) u |= (bin4(vp[k]) << 2*k); + } else { + int s = 1; + for (int k = 0; k < ndim; ++k) { u += s*bin5(vp[k]); s *= 5; } + } + if (u >= int(counts.size())) { + printf("Oops: u = %u, vp = %g, %g, %g, %g\n", u, vp[0], vp[1], vp[2], vp[3]); + u = 0; + if (ncluster == 256) { + for (int k = 0; k < ndim; ++k) { + auto bin = bin4(vp[k]); u |= (bin << 2*k); + printf(" bin[%d] = %d, u = %u", k, bin, u); + } + } else { + for (int k = 0; k < ndim; ++k) printf(" bin[%d] = %d", k, bin5(vp[k])); + } + printf("\n"); + GGML_ABORT("fatal error"); + } + ++counts[u]; + for (int k = 0; k < ndim; ++k) sump[ndim*u + k] += vp[k]; + } + int nzero = 0; + for (int ic = 0; ic < ncluster; ++ic) { + if (!counts[ic]) { + ++nzero; + printf("%s: Oops. Cluster %d has no points: ", __func__, ic); + for (int k = 0; k < ndim; ++k) { + int l = (ic >> 2*k) & 3; + printf(" %d", l); + } + printf("\n"); + } else { + for (int k = 0; k < ndim; ++k) result[ic*ndim + k] = sump[ic*ndim + k]/counts[ic]; + } + } + if (nzero > 0) printf("%s: %d out of %d clusters dir not have any points\n", __func__, nzero, ncluster); + return result; + } + std::mt19937 rndm(1234); + float scale = 1.f/4294967296.f; + for (int i = 0; i < ncluster; ++i) { + auto v = result.data() + i*ndim; + for (int k = 0; k < ndim; ++k) v[k] = range[k].first + (range[k].second - range[k].first)*scale*rndm(); + } + std::vector<int> which_cluster(npoint, -1); + double Flast = Fo; + for (int iter = 0; iter < niter; ++iter) { + std::memset(sump.data(), 0, sump.size()*sizeof(float)); + std::memset(counts.data(), 0, counts.size()*sizeof(int)); + int nchanged = 0; + double F = 0; + for (int ip = 0; ip < npoint; ++ip) { + auto vp = points.data() + ndim*ip; + float best = INFINITY; int ibest = -1; + for (int ic = 0; ic < ncluster; ++ic) { + auto vc = result.data() + ndim*ic; + float dist2 = 0; + for (int k = 0; k < ndim; ++k) { + float d = vp[k] - vc[k]; dist2 += d*d; + } + if (dist2 < best) { + best = dist2; ibest = ic; + } + } + if (ibest < 0) { + printf("Oops(iteration %d) - failed to find cluster for point", iter); + for (int k = 0; k < ndim; ++k) printf(" %g", vp[k]); + printf("\nHave %d clusters\n", ncluster); + } + GGML_ASSERT(ibest >= 0); + F += best; + if (which_cluster[ip] != ibest) ++nchanged; + which_cluster[ip] = ibest; + ++counts[ibest]; + auto vc = sump.data() + ndim*ibest; + for (int k = 0; k < ndim; ++k) vc[k] += vp[k]; + } + if (nchanged == 0) break; + for (int ic = 0; ic < ncluster; ++ic) { + float norm = counts[ic] > 0 ? 
1.f/counts[ic] : 0.f; + auto vc = sump.data() + ndim*ic; + auto r = result.data() + ndim*ic; + for (int k = 0; k < ndim; ++k) r[k] = vc[k]*norm; + } + if (kVerbose) printf("%s(iteration %d): F = %g, nchanged = %d\n", __func__, iter+1, F/points.size(), nchanged); + if (iter > 1 && Flast/F - 1 < 1e-6) break; + Flast = F; + } + int nzero = 0; + for (int ic = 0; ic < ncluster; ++ic) { + if (!counts[ic]) ++nzero; + } + if (nzero > 0) printf("%s: there are %d empty clusters\n", __func__, nzero); + return result; +} + +// ========================================== iq2_kt ==================================================== + +using QuantizerIQ2KT = QuantizerIQKT<32, 8, 16>; + +const QuantizerIQ2KT& iq2kt_quantizer() { + static std::mutex mutex; + static std::unique_ptr<QuantizerIQ2KT> quantizer; + std::lock_guard<std::mutex> lock(mutex); + if (!quantizer) quantizer = std::make_unique<QuantizerIQ2KT>(256, 8); + return *quantizer; +} + +void quantize_row_iq2_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights, + float * qtmp) { + + constexpr float kSigmaScale = 2.0f; + using Q = QuantizerIQ2KT; + + static_assert(Q::kNumVal%8 == 0); + + float * dptr = (float *)vy; + + block_iq2_kt * y = (block_iq2_kt *)(dptr + 1); + + int best_idx[2*Q::kNg]; + + auto& quantizer = iq2kt_quantizer(); + + int nblock = n_per_row / Q::kSuperBlockSize; + + Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); + + float amax_scale = 0, max_scale = 0; + + for (int ibl = 0; ibl < nblock; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq2_kt)); + + const float * xbl = x + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + float amax = 0; + for (int j = 0; j < Q::kBlockSize; ++j) { + float ax = std::abs(xb[j]); + amax = std::max(amax, ax); + } + quantizer.find_best_match( amax/96.f, xb, weight, best_idx); + auto [dp, score_p] = quantizer.find_best_scale(xb, weight, best_idx); + quantizer.find_best_match(-amax/96.f, xb, weight, best_idx + Q::kNg); + auto [dm, score_m] = quantizer.find_best_scale(xb, weight, best_idx + Q::kNg); + + auto idx = best_idx; + if (score_p > score_m) scales[ib] = dp; + else { + scales[ib] = dm; idx += Q::kNg; + } + auto qt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + for (int ig = 0; ig < Q::kNg; ++ig) { + auto q = quantizer.values() + idx[ig]*Q::kGroupSize; + for (int j = 0; j < Q::kGroupSize; ++j) qt[j] = q[j]; + qt += Q::kGroupSize; + } + + float abs_scale = std::abs(scales[ib]); + if (abs_scale > amax_scale) { + amax_scale = abs_scale; + max_scale = scales[ib]; + } + } + + } + + if (!max_scale) { + *dptr = 0; + return; + } + + float d = max_scale/iq4k_values[0]; + float best = 0; + for (int itry = -9; itry <= 9; ++itry) { + float id = (itry + iq4k_values[0])/max_scale; + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * xb = x + ibl*Q::kSuperBlockSize; + const float * qb = qtmp + ibl*Q::kSuperBlockSize; + const float * wb = all_weights + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock; ++ib) { + int ls = best_index_iq4nl(iq4k_values, id*scales[ib]); + float dl = iq4k_values[ls]; + for (int j = 0; j < Q::kBlockSize; ++j) { + float q = dl*qb[j]; + sumqx += wb[j]*xb[j]*q; + sumq2 += wb[j]*q*q; + } + xb += Q::kBlockSize; + wb += 
Q::kBlockSize; + qb += Q::kBlockSize; + } + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; + } + } + + float id = d ? 1/d : 0.f; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock/2; ++ib) { + int ls1 = best_index_iq4nl(iq4k_values, id*scales[ib]); + int ls2 = best_index_iq4nl(iq4k_values, id*scales[ib + Q::kNblock/2]); + y[ibl].scales[ib] = ls1 | (ls2 << 4); + } + } + + *dptr = d; + if (!d) return; + + for (int iloop = 0; iloop < 1; ++iloop) { + + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + + auto qs = (uint16_t *)y[ibl].ql; + const float * xbl = x + ibl*Q::kSuperBlockSize; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + int ls = iq4k_values[(y[ibl].scales[ib%(Q::kNblock/2)] >> 4*(ib/(Q::kNblock/2))) & 0xf]; + float dl = d*ls; + quantizer.find_best_match(dl, xb, weight, best_idx); + + for (int j = 0; j < Q::kNg; ++j) { + qs[j] = best_idx[j]; + auto xl = xb + Q::kGroupSize*j; + auto wl = weight + Q::kGroupSize*j; + auto ql = quantizer.values() + best_idx[j]*Q::kGroupSize; + for (int k = 0; k < Q::kGroupSize; ++k) { + float q = ql[k]*ls; + sumqx += wl[k]*xl[k]*q; + sumq2 += wl[k]*q*q; + } + } + qs += Q::kNg; + } + } + if (sumq2 > 0) { + d = sumqx/sumq2; + *dptr = d; + if (!d) return; + } else { + break; + } + + if (false) { + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * xbl = x + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + auto qs = (uint16_t *)y[ibl].ql; + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + for (int j = 0; j < Q::kNg; ++j) best_idx[j] = qs[ib*Q::kNg+j]; + auto pair = quantizer.find_best_scale(xb, weight, best_idx); + scales[ib] = pair.first; + } + } + float id = d ? 
1/d : 0.f; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock/2; ++ib) { + int ls1 = best_index_iq4nl(iq4k_values, id*scales[ib]); + int ls2 = best_index_iq4nl(iq4k_values, id*scales[ib + Q::kNblock/2]); + y[ibl].scales[ib] = ls1 | (ls2 << 4); + } + } + } + + } + +} +} + +void quantize_row_iq2_kt_ref(const float * GGML_RESTRICT x, block_iq2_kt * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq2_kt(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq2_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_iq2_kt * y = (block_iq2_kt *)vy; + quantize_row_iq2_kt_ref(x, y, k); +} + +size_t quantize_iq2_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row); + std::vector<float> scales(n_per_row/QuantizerIQ2KT::kBlockSize); + std::vector<float> weights(n_per_row); + std::vector<float> xtmp(n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq2_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data()); + src += n_per_row; + qrow += row_size; + } + return nrows * row_size; +} + +void dequantize_row_iq2_kt(const block_iq2_kt * x, float * y, int64_t k) { + assert(k % QuantizerIQ2KT::kSuperBlockSize == 0); + const int nb = k / QuantizerIQ2KT::kSuperBlockSize; + const float * dptr = (const float *)x; + const float d = *dptr * QuantizerIQ2KT::kScale; + x = (const block_iq2_kt *)(dptr + 1); + auto& deq = iq2kt_quantizer(); + for (int ibl = 0; ibl < nb; ++ibl) { + auto yl = y + ibl*QuantizerIQ2KT::kSuperBlockSize; + auto yh = yl + QuantizerIQ2KT::kSuperBlockSize/2; + const uint16_t * ql = (const uint16_t *)x[ibl].ql; + const uint16_t * qh = ql + QuantizerIQ2KT::kNg*QuantizerIQ2KT::kNblock/2; + for (int ib = 0; ib < QuantizerIQ2KT::kNblock/2; ++ib) { + float sl = d * iq4k_values[x[ibl].scales[ib] & 0xf]; + float sh = d * iq4k_values[x[ibl].scales[ib] >> 4]; + for (int ig = 0; ig < QuantizerIQ2KT::kNg; ++ig) { + deq.set_values(ql[ig], yl, sl); + deq.set_values(qh[ig], yh, sh); + yl += QuantizerIQ2KT::kGroupSize; + yh += QuantizerIQ2KT::kGroupSize; + } + ql += QuantizerIQ2KT::kNg; + qh += QuantizerIQ2KT::kNg; + } + } +} + +void vec_dot_iq2_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + +} + +namespace { + +using QuantizerIQ3KT = QuantizerIQKT<32, 8, 16, true>; +const QuantizerIQ3KT& iq3kt_quantizer() { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + static std::unique_ptr<QuantizerIQ3KT> quantizer; + if (!quantizer) quantizer = std::make_unique<QuantizerIQ3KT>(256, 8); + return *quantizer; +} + +void quantize_row_iq3_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, + float * all_weights, float * qtmp) { + + constexpr float kSigmaScale = 2.0f; + constexpr float kStep = 8.0f; + + using Q = QuantizerIQ3KT; + + static_assert(Q::kNumVal%8 == 0); + + constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize; + + float * dptr = (float *)vy; + + block_iq3_kt * y = 
(block_iq3_kt *)(dptr + 1); + + int best_idx[2*Q::kNg]; + + auto& quantizer = iq3kt_quantizer(); + + int nblock = n_per_row / Q::kSuperBlockSize; + + float amax_row = 0; + for (int j = 0; j < n_per_row; ++j) amax_row = std::max(amax_row, std::abs(x[j])); + if (!amax_row) { + *dptr = 0.f; + std::memset(y, 0, nblock*sizeof(block_iq3_kt)); + return; + } + + Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); + + float amax_scale = 0, max_scale = 0; + + float xaux[Q::kBlockSize]; + + for (int ibl = 0; ibl < nblock; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq3_kt)); + + auto scales = all_scales + ibl*Q::kNblock; + auto xbl = x + ibl*Q::kSuperBlockSize; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + float amax = 0; + for (int j = 0; j < Q::kBlockSize; ++j) { + float ax = std::abs(xb[j]); + xaux[j] = ax; + amax = std::max(amax, ax); + } + scales[ib] = 0; + if (!amax) continue; + + //quantizer.find_best_match(amax/96.f, xaux, weight, best_idx+Q::kNg); + //scales[ib] = quantizer.find_best_scale(xaux, weight, best_idx+Q::kNg).first; + + float scale_0 = std::max(84.f, 123.f*amax/amax_row); + //float scale_0 = std::max(64.f, 123.f*amax/amax_row); + float best = 0; + for (int itry = -3; itry <= 3; ++itry) { + quantizer.find_best_match(amax/(scale_0 + kStep*itry), xaux, weight, best_idx); + auto [d, score] = quantizer.find_best_scale(xaux, weight, best_idx); + if (score > best) { + best = score; + scales[ib] = d; + std::memcpy(best_idx+Q::kNg, best_idx, Q::kNg*sizeof(int)); + } + } + + auto xt = qtmp + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + for (int ig = 0; ig < Q::kNg; ++ig) { + auto q = quantizer.values() + Q::kGroupSize*best_idx[Q::kNg+ig]; + for (int j = 0; j < Q::kGroupSize; ++j) *xt++ = q[j]; + } + + float abs_scale = std::abs(scales[ib]); + if (abs_scale > amax_scale) { + amax_scale = abs_scale; + max_scale = scales[ib]; + } + } + + } + + GGML_ASSERT(max_scale >= 0); + float d = max_scale/15; + float best = 0; + for (int itry = -9; itry <= 9; ++itry) { + float id = (itry*0.2f + 15)/max_scale; + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * xb = x + ibl*Q::kSuperBlockSize; + const float * qb = qtmp + ibl*Q::kSuperBlockSize; + const float * wb = all_weights + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock; ++ib) { + int ls = nearest_int(id*scales[ib]); + ls = std::max(0, std::min(15, ls)); + float dl = ls; + for (int j = 0; j < Q::kBlockSize; ++j) { + float q = dl*qb[j]; + sumqx += wb[j]*std::abs(xb[j])*q; + sumq2 += wb[j]*q*q; + } + xb += Q::kBlockSize; + wb += Q::kBlockSize; + qb += Q::kBlockSize; + } + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d*sumqx; + } + } + + float id = d ? 
1/d : 0.f; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto scales = all_scales + ibl*Q::kNblock; + for (int ib = 0; ib < Q::kNblock/2; ++ib) { + int ls1 = nearest_int(id*scales[ib]); + int ls2 = nearest_int(id*scales[ib + Q::kNblock/2]); + ls1 = std::max(0, std::min(15, ls1)); + ls2 = std::max(0, std::min(15, ls2)); + y[ibl].scales[ib] = ls1 | (ls2 << 4); + } + } + + *dptr = d; + + for (int iloop = 0; iloop < 1; ++iloop) { + + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + + uint16_t * ql = (uint16_t *)y[ibl].ql; + + std::memset(y[ibl].qh, 0, kNumGroups/2); + const float * xbl = x + ibl*Q::kSuperBlockSize; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * xb = xbl + Q::kBlockSize*ib; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + for (int j = 0; j < Q::kBlockSize; ++j) { + xaux[j] = std::abs(xb[j]); + if (xb[j] < 0) y[ibl].qh[j] |= (1 << ib); + } + int ls = (y[ibl].scales[ib%(Q::kNblock/2)] >> 4*(ib/(Q::kNblock/2))) & 0xf; + float dl = d*ls; + quantizer.find_best_match(dl, xaux, weight, best_idx); + + for (int j = 0; j < Q::kNg; ++j) { + ql[ib*Q::kNg+j] = best_idx[j]; + auto xl = xaux + Q::kGroupSize*j; + auto wl = weight + Q::kGroupSize*j; + auto ql = quantizer.values() + best_idx[j]*Q::kGroupSize; + for (int k = 0; k < Q::kGroupSize; ++k) { + float q = ql[k]*ls; + sumqx += wl[k]*xl[k]*q; + sumq2 += wl[k]*q*q; + } + } + } + } + if (sumq2 > 0) { + d = sumqx/sumq2; + *dptr = d; + if (!d) break; + } else { + break; + } + } +} +} + +void quantize_row_iq3_kt_ref(const float * x, block_iq3_kt * y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq3_kt(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq3_kt(const float * x, void * vy, int64_t k) { + assert(k % QK_K == 0); + block_iq3_kt * y = (block_iq3_kt *)vy; + quantize_row_iq3_kt_ref(x, y, k); +} + +size_t quantize_iq3_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row); + std::vector<float> scales(n_per_row/QuantizerIQ3KT::kBlockSize); + std::vector<float> weights(n_per_row), xtmp(n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq3_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data(), xtmp.data()); + src += n_per_row; + qrow += row_size; + } + return nrows * row_size; +} + +void dequantize_row_iq3_kt(const block_iq3_kt * x, float * y, int64_t k) { + using Q = QuantizerIQ3KT; + constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize; + assert(k % Q::kSuperBlockSize == 0); + const int nb = k / Q::kSuperBlockSize; + const float * dptr = (const float *)x; + const float d = *dptr * Q::kScale; + x = (const block_iq3_kt *)(dptr + 1); + auto& deq = iq3kt_quantizer(); + for (int ibl = 0; ibl < nb; ++ibl) { + auto yl = y + ibl*Q::kSuperBlockSize; + auto yh = yl + Q::kSuperBlockSize/2; + auto qll = (const uint16_t *)x[ibl].ql; + auto qlh = qll + kNumGroups/2; + int jj = 0; + for (int ib = 0; ib < Q::kNblock/2; ++ib) { + float sl = d * (x[ibl].scales[ib] & 0xf); + float sh = d * (x[ibl].scales[ib] >> 4); + uint8_t l_mask = 1 << ib; + uint8_t h_mask = l_mask << (Q::kNblock/2); + for (int ig = 0; ig < Q::kNg; ++ig) { + deq.set_values(qll[jj], yl, sl); + deq.set_values(qlh[jj], yh, sh); + for (int j = 0; j < Q::kGroupSize; ++j) { + if (x[ibl].qh[ig*Q::kGroupSize+j] & l_mask) yl[j] = -yl[j]; + if (x[ibl].qh[ig*Q::kGroupSize+j] & h_mask) yh[j] = -yh[j]; + } + 
yl += Q::kGroupSize; + yh += Q::kGroupSize; + ++jj; + } + } + } +} + +void vec_dot_iq3_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ3_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + +} + +// ======================================== iq4_kt + +namespace{ + +using QuantizerIQ4KT = QuantizerIQKT<32, 4, 15>; + +const QuantizerIQ4KT& iq4kt_quantizer(bool with_offset = false) { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + static std::unique_ptr<QuantizerIQ4KT> quantizer1; + static std::unique_ptr<QuantizerIQ4KT> quantizer2; + if (with_offset) { + if (!quantizer2) quantizer2 = std::make_unique<QuantizerIQ4KT>(625, 6, 4096+32768); + return *quantizer2; + } + if (!quantizer1) quantizer1 = std::make_unique<QuantizerIQ4KT>(625, 6, 4096); + return *quantizer1; +} + +void quantize_row_iq4_kt_impl(const float * x, void * vy, int n_per_row, const float * quant_weights, float * all_scales, float * all_weights) { + + constexpr float kSigmaScale = 2.0f; + constexpr int kNtry = 2; + using Q = QuantizerIQ4KT; + + static_assert(Q::kNumVal%8 == 0); + + float * dptr = (float *)vy; + + block_iq4_kt * y = (block_iq4_kt *)(dptr + 2); + + auto& quantizer1 = iq4kt_quantizer(); + auto& quantizer2 = iq4kt_quantizer(true); + + int nblock = n_per_row / Q::kSuperBlockSize; + + Q::set_weights(kSigmaScale, nblock, x, quant_weights, all_weights); + + float amax_row = 0, row_av = 0; + for (int j = 0; j < n_per_row; ++j) { + row_av += x[j]; + amax_row = std::max(amax_row, std::abs(x[j])); + } + row_av /= n_per_row; + dptr[1] = row_av; + if (!amax_row) { + dptr[0] = 0.f; + std::memset(y, 0, nblock*sizeof(block_iq4_kt)); + return; + } + + int best_idx[2*Q::kNg]; + float xaux[Q::kBlockSize]; + + float amax_scale = 0, max_scale = 0; + + for (int ibl = 0; ibl < nblock; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq4_kt)); + + const float * xbl = x + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + float amax = 0; + for (int j = 0; j < Q::kBlockSize; ++j) { + xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av; + float ax = std::abs(xaux[j]); + amax = std::max(amax, ax); + } + if (!amax) { + scales[ib] = 0; + continue; + } + float best = 0; + float scale_0 = std::max(92.f, 127.f*amax/amax_row); + for (int itry = -kNtry; itry <= kNtry; ++itry) { + quantizer1.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx); + auto [dp, score_p] = quantizer1.find_best_scale(xaux, weight, best_idx); + if (score_p > best) { + best = score_p; scales[ib] = dp; + } + quantizer1.find_best_match(-amax/(8.f*itry + scale_0), xaux, weight, best_idx); + auto [dm, score_m] = quantizer1.find_best_scale(xaux, weight, best_idx); + if (score_m > best) { + best = score_m; scales[ib] = dm; + } + } + + quantizer2.find_best_match(scales[ib], xaux, weight, best_idx); + auto [d, score] = quantizer2.find_best_scale(xaux, weight, best_idx); + if (score > best) { + scales[ib] = d; + y[ibl].qs[ib] = 1; + } + bool with_offset = false; + for (int itry = -kNtry; itry <= kNtry; ++itry) { + quantizer2.find_best_match( amax/(8.f*itry + scale_0), xaux, weight, best_idx); + auto [dp, score_p] = quantizer2.find_best_scale(xaux, weight, 
best_idx); + if (score_p > best) { + best = score_p; scales[ib] = dp; with_offset = true; + } + quantizer2.find_best_match(-amax/(8.f*itry + scale_0), xaux, weight, best_idx); + auto [dm, score_m] = quantizer2.find_best_scale(xaux, weight, best_idx); + if (score_m > best) { + best = score_m; scales[ib] = dm; with_offset = true; + } + } + if (with_offset) y[ibl].qs[ib] = 1; + + float abs_scale = std::abs(scales[ib]); + if (abs_scale > amax_scale) { + amax_scale = abs_scale; + max_scale = scales[ib]; + } + } + + } + + float d = -max_scale/64; + + dptr[0] = d; + if (!d) return; + + constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize; + + for (int iloop = 0; iloop < 1; ++iloop) { + + const float id = 1/d; + + float sumqx = 0, sumq2 = 0; + for (int ibl = 0; ibl < nblock; ++ibl) { + + // high 3 bits + scales + // each block of 32 needs 8 x 3 (high bits) + 1 x 8 (scale) = 32 bits = 1 x uint32_t + // we have 8 blocks + auto shb = y[ibl].qs; // high 3 bits + scales + auto ql = (uint8_t *)(shb + Q::kNblock); + auto qh = ql + kNumGroups; + std::memset(qh, 0, kNumGroups/2); + const float * xbl = x + ibl*Q::kSuperBlockSize; + auto scales = all_scales + ibl*Q::kNblock; + + for (int ib = 0; ib < Q::kNblock; ++ib) { + auto& quantizer = y[ibl].qs[ib] & 1 ? quantizer2 : quantizer1; + const float * weight = all_weights + ibl*Q::kSuperBlockSize + ib*Q::kBlockSize; + for (int j = 0; j < Q::kBlockSize; ++j) xaux[j] = xbl[ib*Q::kBlockSize+j] - row_av; + int ls = nearest_int(id*scales[ib]); + ls = std::min(ls, 63); + *(uint8_t *)(shb + ib) = ((ls + 64) << 1) | (shb[ib] & 1); + float dl = d*ls; + quantizer.find_best_match(dl, xaux, weight, best_idx); + + for (int j = 0; j < Q::kNg; ++j) { + shb[ib] |= ((best_idx[j] >> 12) << (8 + 3*j)); + ql[Q::kNg*ib + j] = best_idx[j] & 255; + qh[(Q::kNg*ib + j)%(kNumGroups/2)] |= ((best_idx[j] >> 8) & 0xf) << 4*((Q::kNg*ib + j)/(kNumGroups/2)); + auto xl = xaux + Q::kGroupSize*j; + auto wl = weight + Q::kGroupSize*j; + auto ql = quantizer.values() + Q::kGroupSize*best_idx[j]; + for (int k = 0; k < Q::kGroupSize; ++k) { + float q = ql[k]*ls; + sumqx += wl[k]*xl[k]*q; + sumq2 += wl[k]*q*q; + } + } + } + } + if (sumq2 > 0) { + d = sumqx/sumq2; + dptr[0] = d; + if (!d) break; + } else { + break; + } + } +} +} + +void quantize_row_iq4_kt_ref(const float * GGML_RESTRICT x, block_iq4_kt * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq4_kt(x, (void *)y, 1, k, nullptr); +} + +void quantize_row_iq4_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_iq4_kt * y = (block_iq4_kt *)vy; + quantize_row_iq4_kt_ref(x, y, k); +} + +size_t quantize_iq4_kt(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row); + std::vector<float> scales(n_per_row/QuantizerIQ4KT::kBlockSize); + std::vector<float> weights(n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrows; ++row) { + quantize_row_iq4_kt_impl(src, (void *)qrow, n_per_row, imatrix, scales.data(), weights.data()); + src += n_per_row; + qrow += row_size; + } + return nrows * row_size; +} + +void dequantize_row_iq4_kt(const block_iq4_kt * x, float * y, int64_t k) { + using Q = QuantizerIQ4KT; + assert(k % Q::kSuperBlockSize == 0); + constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize; + const int nb = k / Q::kSuperBlockSize; + const float * dptr = (const float *)x; + const float d = dptr[0] * Q::kScale; + 
const float row_av = dptr[1]; + x = (const block_iq4_kt *)(dptr + 2); + auto& deq = iq4kt_quantizer(); + for (int ibl = 0; ibl < nb; ++ibl) { + auto shb = x[ibl].qs; + auto ql = (const uint8_t *)(shb + Q::kNblock); + auto qh = ql + kNumGroups; + for (int ib = 0; ib < Q::kNblock; ++ib) { + int offset = shb[ib] & 1 ? 32768 + 4096 : 4096; + //auto& deq = shb[ib] & 1 ? deq2 : deq1; + int ls = int((shb[ib] & 0xff) >> 1) - 64; + float sl = d * ls; + for (int ig = 0; ig < Q::kNg; ++ig) { + int jj = ib*Q::kNg+ig; + uint16_t idx = ql[jj] | ((qh[jj%(kNumGroups/2)] << (8 - 4*(jj/(kNumGroups/2)))) & 0xf00) | (((shb[ib] >> (8 + 3*ig)) & 7) << 12); + deq.set_values(idx, y, sl, offset); + for (int j = 0; j < Q::kGroupSize; ++j) y[j] += row_av; + y += Q::kGroupSize; + } + } + } +} + +void vec_dot_iq4_kt_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_KT, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + +} diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 9c274d4b..70918a65 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -67,6 +67,24 @@ size_t quantize_iq2_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst void dequantize_row_iq2_ks(const block_iq2_ks * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq2_ks_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq2_kt_ref(const float * GGML_RESTRICT x, block_iq2_kt * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq2_kt(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq2_kt(const block_iq2_kt * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq2_kt_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq3_kt_ref(const float * GGML_RESTRICT x, block_iq3_kt * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq3_kt(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq3_kt(const block_iq3_kt * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq3_kt_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq4_kt_ref(const float * GGML_RESTRICT x, block_iq4_kt * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_kt(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq4_kt(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq4_kt(const block_iq4_kt * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq4_kt_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void 
quantize_row_iq5_ks_ref(const float * GGML_RESTRICT x, block_iq5_ks * GGML_RESTRICT y, int64_t k); void quantize_row_iq5_ks(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); size_t quantize_iq5_ks(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/include/llama.h b/include/llama.h index b6b408de..607a590d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -194,6 +194,9 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_KV = 149, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ5_KS = 150, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_KT = 151, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_KT = 152, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_KT = 153, // except 1d tensors // LLAMA_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors diff --git a/src/llama.cpp b/src/llama.cpp index 9d9c7c4e..48d7214d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4355,6 +4355,9 @@ struct llama_model_loader { case GGML_TYPE_IQ2_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_M_R4;break; case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; case GGML_TYPE_IQ3_XXS_R4: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4; break; + case GGML_TYPE_IQ2_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KT; break; + case GGML_TYPE_IQ3_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ3_KT; break; + case GGML_TYPE_IQ4_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KT; break; case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; case GGML_TYPE_IQ1_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_S_R4;break; case GGML_TYPE_IQ1_M_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ1_M_R4;break; @@ -5095,6 +5098,9 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ2_M_R4: return "IQ2_M_R4 - 2.7 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_KT: return "IQ2_KT - 2.125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_KT: return "IQ3_KT - 3.125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_KT: return "IQ4_KT - 4.0 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4: return "IQ3_XXS_R4 - 3.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_S_R4: return "IQ1_S_R4 - 1.5 bpw"; @@ -18787,10 +18793,11 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || - ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || - ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || - ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || - ftype == LLAMA_FTYPE_MOSTLY_IQ1_M_R4) { + ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_KT || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT) { new_type = 
!qs.has_output ? GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4) { @@ -18818,7 +18825,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4) { new_type = GGML_TYPE_IQ3_S; } - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT) { new_type = GGML_TYPE_IQ3_S; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) { @@ -18863,6 +18870,42 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (name.find("attn_output.weight") != std::string::npos) { new_type = qs.model.hparams.n_expert >= 4 ? GGML_TYPE_Q5_K_R4 : GGML_TYPE_IQ2_K_R4; } + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_KT) { + if (name.find("attn_v.weight") != std::string::npos) { + if (qs.model.hparams.n_expert >= 4 || qs.model.hparams.n_gqa() >= 4) new_type = GGML_TYPE_IQ4_K; + else if (qs.model.hparams.n_gqa() >= 2) new_type = GGML_TYPE_IQ3_K; + else new_type = GGML_TYPE_Q2_K; + ++qs.i_attention_wv; + } + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k") != std::string::npos) { + new_type = GGML_TYPE_Q4_K; + } + else if (qs.model.hparams.n_expert >= 8 && (name.find("blk.0.ffn_down") != std::string::npos || + name.find("blk.0.ffn_gate") != std::string::npos || + name.find("blk.0.ffn_up") != std::string::npos)) { + new_type = GGML_TYPE_IQ3_K; + } + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q") != std::string::npos) { + new_type = GGML_TYPE_Q4_K; + } + else if (name.find("attn_qkv.weight") != std::string::npos) { + new_type = GGML_TYPE_IQ3_K; + } + else if (name.find("_shexp.weight") != std::string::npos) { + new_type = GGML_TYPE_IQ4_K; + } + else if (name.find("ffn_down") != std::string::npos) { + auto [i_layer, n_layer] = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str()); + if (qs.params->ffn_down_type < GGML_TYPE_COUNT) new_type = qs.params->ffn_down_type; + else if (i_layer < n_layer/8) { + new_type = GGML_TYPE_IQ3_K; + } + ++qs.i_ffn_down; + } + else if (name.find("attn_output.weight") != std::string::npos) { + new_type = qs.model.hparams.n_expert >= 4 ? GGML_TYPE_Q5_K : GGML_TYPE_IQ3_K; + } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 || @@ -18919,6 +18962,16 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ3_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT) { + //new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_IQ4_K : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ3_K + // : !qs.has_imatrix ? GGML_TYPE_IQ3_K : GGML_TYPE_IQ3_KT; + new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_IQ4_K : GGML_TYPE_IQ3_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_KT) { + //new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_IQ5_K : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ4_K + // : !qs.has_imatrix ? 
GGML_TYPE_IQ4_KS : GGML_TYPE_IQ4_KT; + new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_IQ5_K : GGML_TYPE_IQ4_K; + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) { new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K_R4 : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ3_K_R4 : !qs.has_imatrix ? GGML_TYPE_IQ3_K_R4 : GGML_TYPE_IQ3_XXS_R4; @@ -19046,6 +19099,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) { new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT && !qs.has_imatrix) { + new_type = i_layer < n_layer/8 ? GGML_TYPE_IQ4_K : GGML_TYPE_IQ3_K; + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 && !qs.has_imatrix) { new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K_R4 : GGML_TYPE_IQ3_K_R4; } @@ -19110,7 +19166,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ5_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ5_KS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_R4 || - ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R8 || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R8 || + ftype == LLAMA_FTYPE_MOSTLY_Q3_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT || ftype == LLAMA_FTYPE_MOSTLY_Q2_K_R4|| ftype == LLAMA_FTYPE_MOSTLY_IQ4_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S_R4) { new_type = GGML_TYPE_Q5_K; // should the IQ_K quants be applied here as the new type for the IQ_K ftypes ? 
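
Editor's note (not part of the patch): the densest piece of the new format code above is the IQ4_KT block layout read back by dequantize_row_iq4_kt. For every block of 32 weights, one uint32_t in qs holds the codebook-offset flag (bit 0), the block scale biased by 64 (bits 1-7), and the top 3 bits of eight 15-bit trellis indices; the low 8 bits of each index live in ql and the middle 4 bits are packed two-per-byte in qh. The standalone sketch below restates that unpacking; the helper name unpack_iq4_kt_block and the constants (8 blocks of 32 per 256-weight super-block, 8 groups of 4 per block) are my own, inferred from QuantizerIQKT<32, 4, 15> and the dequantization loop.

// Hedged sketch, not part of the patch: unpack the metadata of one block of 32
// weights inside an IQ4_KT super-block of 256, following dequantize_row_iq4_kt.
#include <cstdint>

constexpr int kNblock    = 8;   // blocks of 32 weights per super-block (inferred)
constexpr int kNg        = 8;   // groups of 4 weights per block (inferred)
constexpr int kNumGroups = 64;  // groups per super-block (inferred)

struct iq4_kt_block_info {
    int      scale;     // signed block scale; the row-level float d multiplies it
    int      offset;    // codebook offset: 4096, or 4096 + 32768 when bit 0 of shb[ib] is set
    uint16_t idx[kNg];  // 15-bit trellis index per group of 4 weights
};

// shb: kNblock uint32_t words per super-block (flag + biased scale + high index bits)
// ql : kNumGroups low bytes, qh: kNumGroups/2 bytes holding the middle nibbles
inline iq4_kt_block_info unpack_iq4_kt_block(const uint32_t * shb, const uint8_t * ql,
                                             const uint8_t * qh, int ib) {
    iq4_kt_block_info info;
    info.offset = shb[ib] & 1 ? 4096 + 32768 : 4096;
    info.scale  = int((shb[ib] & 0xff) >> 1) - 64;   // stored as ((ls + 64) << 1) | offset_flag
    for (int ig = 0; ig < kNg; ++ig) {
        int jj = ib*kNg + ig;
        info.idx[ig] = ql[jj]
                     | ((qh[jj%(kNumGroups/2)] << (8 - 4*(jj/(kNumGroups/2)))) & 0xf00)
                     | (((shb[ib] >> (8 + 3*ig)) & 7) << 12);
    }
    return info;
}

If these inferred constants are right, a super-block costs 8*4 bytes of qs words, 64 bytes of ql and 32 bytes of qh, i.e. 128 bytes per 256 weights, which lines up with the "IQ4_KT - 4.0 bpw" ftype string (the per-row float scale and row mean stored ahead of the blocks add a small constant on top).
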
@@ -19119,6 +19176,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n } else { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K; // This list could be generalized and streamlined else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT && qs.model.hparams.n_gqa() >= 4) new_type = GGML_TYPE_IQ3_K; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) new_type = GGML_TYPE_IQ3_K_R4; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K; @@ -19321,10 +19379,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; case LLAMA_FTYPE_MOSTLY_IQ2_XS_R4:default_type = GGML_TYPE_IQ2_XS_R4; break; case LLAMA_FTYPE_MOSTLY_IQ2_KS: default_type = GGML_TYPE_IQ2_KS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_KT: default_type = GGML_TYPE_IQ2_KT; break; case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break; case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break; case LLAMA_FTYPE_MOSTLY_IQ2_M_R4:default_type = GGML_TYPE_IQ2_S_R4;break; case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break; + case LLAMA_FTYPE_MOSTLY_IQ3_KT: default_type = GGML_TYPE_IQ3_KT; break; + case LLAMA_FTYPE_MOSTLY_IQ4_KT: default_type = GGML_TYPE_IQ4_KT; break; case LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4: default_type = GGML_TYPE_IQ3_XXS_R4; break; case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break; case LLAMA_FTYPE_MOSTLY_IQ1_S_R4:default_type = GGML_TYPE_IQ1_S_R4;break;
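
Editor's note (not part of the patch): one detail that is easy to miss in the IQ3_KT code is that the trellis only ever sees magnitudes. quantize_row_iq3_kt_impl feeds find_best_match with xaux = |x| and records the sign separately, setting y[ibl].qh[j] |= (1 << ib) for negative inputs, and dequantize_row_iq3_kt flips the corresponding outputs with the l_mask/h_mask bits. Under the layout implied by those two loops (256-weight super-blocks, 32-byte qh), the sign of the q-th weight can be read back with a one-liner; the helper name is mine.

// Hedged sketch: sign lookup for weight position q (0..255) inside one IQ3_KT
// super-block, derived from the quantize/dequantize loops above.
#include <cstdint>

inline bool iq3_kt_is_negative(const uint8_t * qh /* 32 bytes per super-block */, int q) {
    return (qh[q % 32] >> (q / 32)) & 1;   // byte = position within its block of 32, bit = block index
}

Keeping the signs out of the search means the codebook only has to reproduce magnitudes; the budget then appears to work out to 64 bytes of 16-bit group indices + 32 bytes of signs + 4 bytes of packed 4-bit scales = 100 bytes per 256 weights, matching the "IQ3_KT - 3.125 bpw" ftype string.
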
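Editor's note (not part of the patch): the new entry points declared in iqk_quantize.h follow the usual ggml quantization pattern, e.g. quantize_iq2_kt walks a matrix row by row (optionally using an importance matrix) and returns the number of bytes written, while dequantize_row_iq2_kt reverses a single row. A hedged round-trip sketch; rows, cols and src are placeholders, cols must be a multiple of QK_K, and the include paths are assumptions.

// Hedged usage sketch for the new IQ2_KT entry points.
#include <cstdint>
#include <vector>
#include "ggml.h"              // ggml_row_size, GGML_TYPE_IQ2_KT (assumed include path)
#include "iqk/iqk_quantize.h"  // quantize_iq2_kt, dequantize_row_iq2_kt (assumed include path)

void iq2_kt_roundtrip(const float * src, int64_t rows, int64_t cols) {
    std::vector<char> buf(rows * ggml_row_size(GGML_TYPE_IQ2_KT, cols));
    size_t written = quantize_iq2_kt(src, buf.data(), rows, cols, /*imatrix=*/nullptr);
    (void)written;                       // equals rows * ggml_row_size(GGML_TYPE_IQ2_KT, cols)

    std::vector<float> first_row(cols);  // dequantize the first row back to f32
    dequantize_row_iq2_kt((const block_iq2_kt *)buf.data(), first_row.data(), cols);
}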