path: root/src
Age  Commit message  Author
2025-05-22  Streamline a bit the quant strategies (#443)  Nexes the Elder
* Streamline a bit the quant strategies No change over the existing patterns, except for the bump for attn_k and attn_v for the models with 4 and 6 experts (several frankensteins seen on HF, and which also use GQA). The rest is applying the existing patterns to the new IQ_K quants. Also, a Q8_0 for attn_q slipped into the MOEs 8 experts rule, I removed it, because that tensor is much bigger than attn_k or attn_v. * remove <=8 experts condition.
2025-05-17  IQ5_KS_R4: row-interleaved IQ5_KS (#426)  Kawrakow
* iq5_ks_r4: basics * iq5_ks_r4: Zen4 works * iq5_ks_r4: AVX2 works * iq5_ks_r4: NEON * Fix iq5_ks on NEON --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-05-15  Adding IQ5_KS - 5.25 bpw quants (#422)  Kawrakow
* iq5_ks: basics * iq5_ks: quantize * iq5_ks: CUDA dequantize works * iq5_ks: dot product works on CUDA * iq5_ks: MMQ works * iq5_ks: Zen4 * iq5_ks: AVX2 But it is not quite right, just like iq4_k, iq5_k, iq6_k, iq4_ks. All these need fixing on AVX2. * iq5_ks: NEON * iq5_ks: Metal dequantize * iq5_ks: Metal dot product --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-05-12  Enable faster prompt processing with mainline llama.cpp GGUFs (#409)  Kawrakow
* Enable MLA-3 in crippled GGUFs: WIP * Enable MLA-3 in crippled GGUFs: seems to work * Add newly created tensors to model.tensors_by_name Else they don't get run-time repacked. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-05-12  Faster DeepSeek FA on CUDA (#408)  Kawrakow
* New DeepSeek FlashMLA Does not work because the RoPE portion is stored at the end in our case, while in mainline it is stored at the beginning, and the FA kernel assumes that. * Rearrange MLA K cache so it fits the new CUDA FA implementation * constexpr and minor changes --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-05-12  GPU offload policy (#405)  Kawrakow
* Adding GPU offload policy * Minor --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-05-09  Handle incompatible DeepSeek GGUFs (#394)  Kawrakow
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-05-09  Support for Llama-3-Nemotron models (#377)  saood06
* conflict resolution * Changes to make work and add longrope support * Changes to n_attention_wv rule * Untested support of 253B * DeciLMCausalModel now reads rope_theta from config.json properly * Remove errant Granite mentions * Better n_attention_wv rule * Update vocab.py --------- Co-authored-by: Yee Man Chan <ymchan@gmail.com> Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-05-02  Fix model architecture name (#366)  saood06
Co-authored-by: junhuihe <junhui-he@outlook.com>
2025-04-29  Apply Qwen3 PR from llama.cpp (#355)  Ben Harris
2025-04-26  Add GLM-4-0414 Model Support (#344)  ubergarm
* Add GLM-4-0414 model support Based on zRzRzRzRzRzRzR's PR on mainline llama.cpp. Still some issues where it doesn't work: * offloading >=60 layers to GPU * no flash attention * Remove seemingly unused llm_tensor enums Both of these seem unused and LLM_TENSOR_ATTN_POST_NORM already existed which seems pretty similar? Don't think they were used in the python code either... So removed these as possibly just cruft: * LLM_TENSOR_POST_ATTN_NORM * LLM_TENSOR_POST_MLP_NORM * Set flash attention precision to f32 on GLM4 arch * Set non flash attention precision to f32 on GLM4 * Remove reshape_3d() for Vcur in build_glm4() This fixes the non-flash-attention inferencing on both CPU and CUDA.
2025-04-26  Add support for Cohere2 (#341)  Kawrakow
* Add support for Cohere2 * Fix IQ4_NL on AVX2 * Command-A needs fp32 precision for K*Q --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-04-25  Fix LLaMA-4 attention (#342)  Kawrakow
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-04-22  BitNet adjustments (#338)  Kawrakow
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-04-22  Add support for bitnet2b_2501 model (#337)  saood06
* add support for bitnet2b_2501 model * Fixes * Support both model names --------- Co-authored-by: potassiummmm <zhou.hansong@outlook.com>
2025-04-11  Correct L4 rms_norm (#324)  Kawrakow
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-04-10  LLaMA-4 support (text only) (#321)  Kawrakow
* llama4: WIP * llama4: this seems to be working --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-04-08  Guard against attempts to use MLA for non-MLA models (#320)  Kawrakow
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-04-07  Add copyright notices (#317)  Kawrakow
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-04-01  Additional guards for interleaved quants (#299)  Kawrakow
* Make sure no interleaved quants are being used for token embeddings also with `--pure` and/or `--custom-q`. * Simplify --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-27  Make sure tensor row size is multiple of block size also when quantizing with --pure (#294)  Kawrakow
* WIP - not working * q8_0 without bells and whistles works * It works for q8_0 * Use bf16 instead of f16,int16 * q4_0_r8 * q5_0_r4 * q6_0_r4 * Also q4_1 and q5_1 * Add check if selected type is possible with --pure I often want to quantize with --pure to see quantization performance without quantization mixes. But for models where there are tensors with row sizes that are not a multiple of 256, this results in a crash for k- and i-quants. Hence, let's add a check if the quant selected via --pure is applicable, and change it if not. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
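The `--pure` check described in #294 can be sketched roughly as below. This is only an illustration, not the code from the quantize tool: the helper `pick_pure_type`, the `BLOCK_SIZE` table, and the choice of `q8_0` as the replacement type are assumptions. Only the constraint that k- and i-quants need row sizes that are a multiple of 256 comes from the commit message.

```python
# Elements per quantization block for a few illustrative types.
# k- and i-quants use 256-element super-blocks; q8_0 uses 32-element blocks.
BLOCK_SIZE = {"q4_K": 256, "iq4_k": 256, "q8_0": 32}


def pick_pure_type(requested: str, row_size: int, fallback: str = "q8_0") -> str:
    """Return `requested` if the row size allows it, else a fallback type.

    The fallback policy is hypothetical; the real tool may choose differently.
    """
    if row_size % BLOCK_SIZE[requested] == 0:
        return requested
    if row_size % BLOCK_SIZE[fallback] != 0:
        raise ValueError(f"row size {row_size} fits neither {requested} nor {fallback}")
    return fallback


print(pick_pure_type("q4_K", 4096))  # row size is a multiple of 256
print(pick_pure_type("q4_K", 576))   # 576 is not a multiple of 256
```

Without such a check, the second call would correspond to the crash case the commit message mentions.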
2025-03-23  Improve DeepSeek batched processing speed (#282)  Kawrakow
* Improve DeepSeek batched processing speed * Revert the commented out section in iqk_mul_mat.cpp It does have some benefit at long contexts. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-23  Test transparent huge pages on Linux (#278)  Kawrakow
* Adding ability to use THP on Linux * Use the actual page size used for mmap also in munmap * Add -thp to llama-bench --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-22  Add Gemma3 support (text only) (#276)  Kawrakow
* WIP Gemma3: not working * gemma3: build_gemma3 seems to be working now * Revert changes to convert_hf_to_gguf.py It wasn't working, so I guess, it is better to leave the conversion up to upstream. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-21  Fix bug: missing parentheses in logical expression (#275)  Kawrakow
This results in GGGGGGGGGGGGG when generating with mla = 3, fa = 0. Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-21  Specify tensor name regex for tensors to be repacked (#274)  Kawrakow
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-21  FlashMLA-3: the best of both worlds (CPU only) (#273)  Kawrakow
* Repack a model with the quantize tool * WIP * Fixed various issues As we don't have a way to tell if a repacked quant has been modified, I had to remove the modification at the expense of a slight decrease in performance. This affects q8_0_r8, q8_KV_r8, q8_k_r8 on Zen4, and q4_0_r8 on ARM. * Create wk_b and wv_b as Q8_0_R8 if the wkv_b type is interleaved * Fix GCC 13.3 compilation error * Another one * Add missing include * FlashMLA-3: the best of both worlds - CPU only --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-21  Convert models to row-interleaved quants using the quantize tool (#272)  Kawrakow
* Repack a model with the quantize tool * WIP * Fixed various issues As we don't have a way to tell if a repacked quant has been modified, I had to remove the modification at the expense of a slight decrease in performance. This affects q8_0_r8, q8_KV_r8, q8_k_r8 on Zen4, and q4_0_r8 on ARM. * Create wk_b and wv_b as Q8_0_R8 if the wkv_b type is interleaved * Fix GCC 13.3 compilation error * Another one * Add missing include --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-19  Honor mmap setting when using tensor overrides (#270)  Kawrakow
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-18  Make Q8_0 KV cache work with mla=2,fa on CUDA (#264)  Kawrakow
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-18  FlashMLA-2: reduce compute buffer size (CUDA and CPU) (#260)  Kawrakow
* FlashMLA-2: eliminate intermediate f32 tensors This works on the CPU. PP performance is ~13% better for 16k tokens and compute buffer is quite a bit smaller. * FlashMLA-2: enable fast path only on the CPU for now I did implement the necessary ops on CUDA, but something is still wrong there, so for now we only use it when running CPU-only. * FlashMLA-2: slightly smaller compute buffer size * Prepare wk_b when loading DeepSeek models (if wk_b is missing) * Add some comments * Fix case where wkv_b is quantized with k- or i-quants. * Fix CUDA There is an issue with quantized GEMV on CUDA when the left operand (the matrix) is not contiguous. So, for now, we also create wv_b during model loading and use that instead of the 3D view of wkv_b. * FlashMLA-2: avoid conversions to f32 also on CUDA * Be able to compute for more than 65535 tokens On CUDA just a quick hack that allows us to concatenate tensors with more than 65535 rows along the zeroth dimension as needed by FlashMLA-2. Also needed some care in the perplexity tool to avoid int overflows when evaluating the computed logits. * Reduce memory usage for FlashMLA-2 Oh, also fix int overflow in the CUDA concat implementation. It is funny how the llama.cpp 64-bit police has gone (almost) everywhere and replaced 32-bit ints with 64-bit ints, needed or not, but hasn't done it where it is actually needed. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-17  Prepare wk_b tensors of DeepSeek models on the fly (#259)  Kawrakow
* FlashMLA-2: eliminate intermediate f32 tensors This works on the CPU. PP performance is ~13% better for 16k tokens and compute buffer is quite a bit smaller. * FlashMLA-2: enable fast path only on the CPU for now I did implement the necessary ops on CUDA, but something is still wrong there, so for now we only use it when running CPU-only. * FlashMLA-2: slightly smaller compute buffer size * Prepare wk_b when loading DeepSeek models (if wk_b is missing) * Add some comments * Fix case where wkv_b is quantized with k- or i-quants. * Fix CUDA There is an issue with quantized GEMV on CUDA when the left operand (the matrix) is not contiguous. So, for now, we also create wv_b during model loading and use that instead of the 3D view of wkv_b. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-13  FlashMLA-2 (CPU): faster and smaller compute buffer size (#253)  Kawrakow
* FlashMLA-2: eliminate intermediate f32 tensors This works on the CPU. PP performance is ~13% better for 16k tokens and compute buffer is quite a bit smaller. * FlashMLA-2: enable fast path only on the CPU for now I did implement the necessary ops on CUDA, but something is still wrong there, so for now we only use it when running CPU-only. * FlashMLA-2: slightly smaller compute buffer size --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-10  DeepSeek imatrix stuff (#250)  Kawrakow
* This gives us ~20% TG speedup for DeepSeek on CUDA * Slightly better * Also do it for plain (not fused) mul_mat_id * Guard against numerical precision issues for MLA on CUDA * imatrix: wv_b <-> wkv_b --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-10  Faster MoE token generation on CUDA (#248)  Kawrakow
* This gives us ~20% TG speedup for DeepSeek on CUDA * Slightly better * Also do it for plain (not fused) mul_mat_id * Guard against numerical precision issues for MLA on CUDA --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-09  This works on CUDA, but (#247)  Kawrakow
PP speed is great, almost on par with standard FA. But TG speed is pathetic. The strangest thing is that the slowdown is not due to FA, but due to the ffn_gate_exps gemm, which somehow becomes very slow. WTF? As I'm unable to resolve the slow ffn_gate_exps GEMM mystery, for now TG goes via mla=2, PP is via FA. Also discovered the ggml_cast op, so we don't need the aux tensors that I had added to the KV cache. Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-08  Faster FlashMLA prompt processing (#246)  Kawrakow
* FlashMLA-2: faster prompt processing

The current MLA implementation computes wv_b * (k_cache * softmax(k_cache * (wk_b*q))). This leads to 3.4X more multiply-adds (madds) compared to standard attention. Due to the resulting tensor shapes, TG is still faster than standard attention because the k_cache*(wk_b*q) and k_cache*(softmax(k_cache * (wk_b*q))) multiplications become GEMMs, so the additional madds are more than compensated for due to the much higher performance of GEMMs compared to GEMVs. But for PP, where we are dealing with GEMMs in both cases, the additional madds needed for MLA lead to lower performance, with the performance gap increasing with context length.

So, then, when we are dealing with PP, we can rearrange the above to (wv_b * k_cache) * softmax( (wk_b^T*k_cache) * q), thus transforming it into the standard attention mechanism. We do need two additional matrix multiplications (which in practice is done as a single wkv_b * k_cache GEMM) with the *entire* K cache. But this is still cheaper than MLA, as we end up with 1.8X the madds required by standard attention. Oh, these figures are for the DeepSeek-V3/R1/Lite attention architecture. This leads to a significant PP performance increase compared to standard MLA with FA.

There are many upsides to this:
  * If we only apply the above trick when we are processing more than X tokens (with suitably chosen X), TG performance stays the same as MLA with FA
  * We still need to store just the K-cache, so 576 entries per layer for DeepSeek-V3/R1/Lite
  * We get significantly better PP performance
  * We can use MLA+FA on CUDA. It works already with this commit for PP, something is not yet quite right for TG.

The downside is that it only works with fp16 cache (for now). This is so because we need to convert the cache to fp32, else we cannot do the wkv_b * k_cache matrix multiplication (which in ggml requires the second operand to be fp32). But converting (copying) to fp32 only works for f16, bf16 and f32 tensors, so no luck with quantized cache. Another reason that we need to convert to fp32 is that the cache contains the RoPE'd portion, which we need to concatenate to the result of the wkv_b * k_cache matrix multiplication. Also this op works only when the tensors being concatenated are both fp32. So much about ggml being a general purpose ML library.

* FlashMLA-2: on the CPU it now works for quantized cache except for q8_KV (q8_KV has row meta data, and there is still some confusion with row sizes because of that).
* FlashMLA-2: on the CPU it now works also with q8_KV

--------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
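The rearrangement described in #246 relies only on associativity of matrix products, which a small numerical check can illustrate. The sketch below is a simplification and not the ggml code: it uses a single head, leaves out the RoPE portion and the attention scale, and the shapes and function names (`attention_mla`, `attention_standard`) are assumptions chosen for the example. The operand order follows ordinary math notation rather than ggml's conventions.

```python
import math
import random


def transpose(m):
    return [list(col) for col in zip(*m)]


def matvec(m, v):
    """Matrix (list of rows) times vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def matmul(a, b):
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]


def softmax(x):
    mx = max(x)
    e = [math.exp(v - mx) for v in x]
    s = sum(e)
    return [v / s for v in e]


def attention_mla(cache, wk_b, wv_b, q):
    """MLA form: absorb wk_b into the query, attend in the latent space."""
    q_latent = matvec(transpose(wk_b), q)   # d_c
    p = softmax(matvec(cache, q_latent))    # n_kv
    context = matvec(transpose(cache), p)   # d_c
    return matvec(wv_b, context)            # d_v


def attention_standard(cache, wk_b, wv_b, q):
    """Rearranged form: materialize K and V from the cache, then attend."""
    k = matmul(cache, transpose(wk_b))      # n_kv x d_h
    v = matmul(cache, transpose(wv_b))      # n_kv x d_v
    p = softmax(matvec(k, q))               # n_kv
    return matvec(transpose(v), p)          # d_v


random.seed(0)
n_kv, d_c, d_h, d_v = 6, 5, 4, 3            # toy sizes, not DeepSeek's


def rand(r, c):
    return [[random.uniform(-1, 1) for _ in range(c)] for _ in range(r)]


cache, wk_b, wv_b = rand(n_kv, d_c), rand(d_h, d_c), rand(d_v, d_c)
q = [random.uniform(-1, 1) for _ in range(d_h)]

a = attention_mla(cache, wk_b, wv_b, q)
b = attention_standard(cache, wk_b, wv_b, q)
print(max(abs(x - y) for x, y in zip(a, b)))  # tiny floating-point difference
```

Both forms give the same result up to rounding; they differ only in which products are formed first, which is what changes the madd count for TG versus PP.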
2025-03-07  Custom quantization rules with regular expressions (#244)  Kawrakow
* Custom quantization rules with regular expressions * Add the --custom-q option to the help --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
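As a rough picture of what regex-based quantization rules (#244) could look like, here is a small sketch. The rule string format (`pattern=type` pairs separated by commas), the first-match-wins policy, and the helpers `parse_rules` and `type_for` are assumptions made for illustration; consult the tool's help output for the actual `--custom-q` syntax.

```python
import re


def parse_rules(spec: str):
    """Parse a hypothetical 'regex=type,regex=type' rule string.

    Splitting on ',' means patterns themselves cannot contain commas here.
    """
    rules = []
    for item in spec.split(","):
        pattern, qtype = item.rsplit("=", 1)
        rules.append((re.compile(pattern), qtype))
    return rules


def type_for(tensor_name: str, rules, default: str) -> str:
    """Return the type of the first rule whose regex matches the tensor name."""
    for pattern, qtype in rules:
        if pattern.search(tensor_name):
            return qtype
    return default


rules = parse_rules(r"attn_(k|v)\.weight=q8_0,ffn_.*_exps=iq4_k")
for name in ("blk.3.attn_k.weight", "blk.3.ffn_down_exps.weight", "output.weight"):
    print(name, "->", type_for(name, rules, default="q4_K"))
```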
2025-03-05  DeepSeek CUDA Flash Attention (#241)  Kawrakow
* WIP CUDA FA with Dk != Dv * WIP * CUDA FA WIP - It actually works! No TG yet, but for PP I can run FA with fp16 cache and it gets the same answer. * CUDA FA WIP - it now works for Q8_0 + Q8_0 for KV cache * CUDA FA WIP - TG, not working yet. * CUDA FA with Dk != Dv: it works now for DeepSeek --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-03  Flash MLA (CPU only) (#240)  Kawrakow
* FlashMLA - it finally works (on the CPU) * FlashMLA: allow for f16 and bf16 cache in addition to q8_0 * It works with ggml FA, not with iqk FA * WIP * FlashMLA: it now works with iqk I had forgotten to divide the Q stride by sizeof(float) and that's why, very confusingly, it was working for TG but not for PP. * WIP * FlashMLA: that should be it for now --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
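The stride bug mentioned in #240 (a byte stride used where an element stride was expected) can be illustrated with a toy buffer. This is a sketch under assumptions: `get_row` is a hypothetical helper, and the explanation that TG touches only the first query row is an inference from the commit message, not something it states outright.

```python
from array import array


def get_row(buf, i, row_len, stride):
    """Read row i from a flat buffer, with `stride` counted in elements."""
    start = i * stride
    return list(buf[start:start + row_len])


row_len = 4
q = array("f", range(2 * row_len))           # two query rows of 4 floats each
stride_bytes = row_len * q.itemsize          # stride as stored, in bytes
stride_elems = stride_bytes // q.itemsize    # stride needed for float indexing

# Row 0 starts at offset 0 either way, so a single-row query is unaffected.
print(get_row(q, 0, row_len, stride_bytes), get_row(q, 0, row_len, stride_elems))

# Row 1 is only found with the element stride; the byte stride overshoots.
print(get_row(q, 1, row_len, stride_elems), get_row(q, 1, row_len, stride_bytes))
```

In Python the overshooting slice simply comes back empty; in C or C++ the same mistake would read unrelated memory.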
2025-03-02  SER - Smart Expert Reduction (#239)  Kawrakow
* A better way to measure the cost of ggml_barrier * Smart expert selection * Add ser option to llama-bench --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-03-01  Reduce size of compute buffers (#237)  Kawrakow
* This reduces compute buffer size for MLA * This should accomplish it for standard attention * Much better * Better concat for contiguous tensors If all the op does is to concatenate the second tensor to the first, why would we want to have a loop? --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-02-27  Option to use MLA without a transposed cache (#235)  Kawrakow
The `-mla` command line option turns into an int from a bool. mla = 0: use standard attention; mla = 1: use MLA with transposed cache; mla > 1: use MLA without transposed cache. Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
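The three `-mla` cases listed in #235 map to a simple dispatch. The function below is a hypothetical illustration, not code from the repository; rejecting negative values is an assumption, since the commit message does not say how they are handled.

```python
def attention_mode(mla: int) -> str:
    """Describe the attention variant selected by an integer -mla value."""
    if mla < 0:
        # Assumption: negative values are treated as invalid in this sketch.
        raise ValueError("mla must be non-negative")
    if mla == 0:
        return "standard attention"
    if mla == 1:
        return "MLA with transposed cache"
    return "MLA without transposed cache"


for value in (0, 1, 2, 3):
    print(value, "->", attention_mode(value))
```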
2025-02-27  Faster MLA on CUDA (#234)  Kawrakow
* Slight MLA TG performance improvement on CUDA The low MLA performance on CUDA is due to the wk_b * q_nope operation. It turns into n_head matrix multiplications with n_head separate quantization and GEMV steps. The associated overhead is just too much for TG where each GEMV is very fast (512 x 128 = 131 KFLOP for DeepSeek-Lite, 4X that for DeepSeekV3/R1). The way it was done, there was also a copy of each q_nope row before quantization, which I have now eliminated. This results in a ~2.5% speedup. What needs to happen instead is to launch a single computation that quantizes all heads, and then have a kernel that does the GEMV for all heads instead of n_head sequential GEMVs. * Slightly better * CUDA: Quantize non-contiguous tensors * Much better MLA It is a total hack, but it works. * Cleanup Remove duplicated gemv's. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
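The per-GEMV figure quoted in #234 can be reproduced with a line of arithmetic, assuming it counts two floating-point operations (one multiply, one add) per matrix element. That counting convention and the helper name `gemv_flops` are assumptions; the 512 x 128 shape and the 4X factor are from the commit message.

```python
def gemv_flops(rows: int, cols: int) -> int:
    """FLOPs for one matrix-vector product, counting a multiply-add as 2 FLOPs."""
    return 2 * rows * cols


lite = gemv_flops(512, 128)
print(lite)       # 131072, i.e. about 131 KFLOP per head for DeepSeek-Lite
print(4 * lite)   # 524288, the "4X that" quoted for DeepSeek-V3/R1
```

At this size each GEMV finishes so quickly that the per-head quantization and launch overhead the commit describes can dominate.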
2025-02-25  Give the user the option to override where model weights are stored (#232)  Kawrakow
* Give the user the option to override where model weights are stored * Fix ggml_nbytes() problem and cleanup For a tensor with zero elements ggml_nbytes() was returning uint64_t::max, and this was causing graph allocation failure. * Add timing info to CUDA graph evaluation * Add more timing info --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-02-23  Fused MoE ffn_up and ffn_gate (#229)  Kawrakow
* Fusing MoE up * unary(gate) * Fusing MoE up * unary(gate): CUDA We get ~13% speedup for PP-512 and ~2% for TG-128 for DeepSeek-Lite * On CUDA also fuse MoE down * (up * unary(gate)) in case the MUL_MAT_ID op for the down experts is the next op in the graph. * Command line option to enable fused MoE up*unary(gate) * Add fmoe option to llama-bench * Adding forgotten gelu, relu, silu on ARM --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
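The idea behind fusing `up * unary(gate)` in #229 can be shown on plain lists. This is a conceptual sketch only: the real change fuses ggml graph ops and kernels, while the functions `unfused` and `fused` here are hypothetical, and SiLU is used as one example of the unary activation.

```python
import math


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def unfused(up, gate, act=silu):
    """Two passes: materialize act(gate), then multiply by up."""
    activated = [act(g) for g in gate]            # intermediate buffer
    return [u * a for u, a in zip(up, activated)]


def fused(up, gate, act=silu):
    """One pass: apply the activation and the multiply together."""
    return [u * act(g) for u, g in zip(up, gate)]


up = [0.5, -1.0, 2.0, 3.0]
gate = [1.0, 0.0, -2.0, 0.25]
print(fused(up, gate) == unfused(up, gate))
```

The fused version produces the same values while skipping the intermediate buffer and one pass over the data, which is where a speedup of the kind the commit reports would come from.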
2025-02-20  Honor attn_output specified in the command line also for low-bit quants  Iwan Kawrakow
2025-02-19  Q8_KV: 8-bit quantization type targeting the KV cache (#208)  Kawrakow
* Adding q8_KV - Basics + AVX2 gemm/gemv * q8_KV: Better AVX2 gemm * q8_KV: Better Zen4 gemm We get 225.7 t/s for L3-8B. In comparison q8_0 without run-time-repacking is at 169 t/s. * q8_KV: AVX2 gemm/gemv We get 254 t/s for L3-8B vs 194 t/s for q8_0 without rtr. * q8_KV: be able to use it for K cache This required quite a few fixes in ggml and llama.cpp: * ggml: do not calculate row size as n/block_size*type_size. I had removed most of it when implementing the quants with per row scale, but it was still lurking in ggml_copy. Not sure if these were the last remnants of ggml-style row sizes, or if there are still places left * llama.cpp: get rid of the 1d K cache assumption. Create and manage the K-cache as a 2D tensor so we can have per row meta data as needed by q8_KV. Using q8_KV for K-cache results in non-negligible performance gains. More details to follow, but for DeepSeek-Lite with MLA, we get 18% speedup for PP-8192 compared to q8_0 K-cache. * q8_KV: be able to use it for K cache in FA * q8_KV: repack it for K*Q in FA * q8_KV: slightly faster gemv on Zen4 * q8_KV: slightly faster gemv on Zen4 * q8_KV: ARM_NEON We get PP-512 = 167 t/s for L3-8B without interleaving! We do the interleaving on the fly, so I wonder if this could be done for other quants as well. * q8_KV: use it in FA on NEON * q8_KV_r8 - repacked q8_KV On Zen4 it is slower than q8_k_r8 (292 vs 370 t/s) This makes no sense whatsoever as the q8_KV_r8 GEMM is basically the q8_k_r8 GEMM with the unnecessary block stuff removed (so, one would think that it would be faster). * q8_KV_r8: don't use nrc_y = 16 on Zen4 This is faster - 350 t/s. Why? Much better than the 290 t/s we had before, but still slower than the 370 t/s for q8_k_r8. * q8_KV: nrc_y = 16 also doesn't pay off in FA * Minor --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
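The row-size point in #208 is that `n/block_size*type_size` cannot describe a type that stores meta data once per row. A small comparison makes this concrete. The q8_0 layout (32 int8 values plus a 2-byte scale, 34 bytes per block) is standard; the 8 bytes of per-row meta data used for the second formula is an assumption for illustration, as are both helper names.

```python
def row_size_blocked(n: int, block_size: int, type_size: int) -> int:
    """Classic block-quant row size: whole blocks, each type_size bytes."""
    assert n % block_size == 0, "row must hold a whole number of blocks"
    return n // block_size * type_size


def row_size_per_row_meta(n: int, bytes_per_value: int, row_meta_bytes: int) -> int:
    """Row size for a type with one meta-data record per row."""
    return row_meta_bytes + n * bytes_per_value


n = 576  # K-cache entries per layer quoted for DeepSeek-V3/R1/Lite
print(row_size_blocked(n, block_size=32, type_size=34))                 # q8_0
print(row_size_per_row_meta(n, bytes_per_value=1, row_meta_bytes=8))    # q8_KV-like
```

Because the second size is not a multiple of any fixed block size, code that derives row sizes from block counts alone gets it wrong, which matches the ggml_copy issue the commit describes.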
2025-02-13  MLA: allow Q8_0 K-cache for MLA (#206)  Kawrakow
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2025-02-13  Faster MLA prompt processing (#205)  Kawrakow
* Do not allocate / report caches that are not used It is either the standard KV cache or MLA cache, not both. * Rename X_pe to X_rope Much easier to follow, at least for my brain, when we have X_rope : rotational position encoding X_nope : no position encoding instead of X_pe and X_nope, where I was wondering wtf is 'pe' and 'nope'. * WIP * WIP * WIP * WIP * Warn user when disabling MLA * MLA: compile time option to not use transposed KV cache Cuts KV cache size in nearly half at the expense of slower TG performance for long contexts (it becomes similar to no-MLA). --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>