diff options
author | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-07-24 16:49:00 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-24 16:49:00 +0200 |
commit | 47c1243e3c0cec135cc456b35ace51c37db7e8df (patch) | |
tree | 18f1fe14515df53169319101a7956f3244ef79ca | |
parent | 8fe7e04456d3bf3fd48e83bf0194f68abb6e80a1 (diff) |
Update README.md
Adding MoE and Bitnet performance tables
-rw-r--r-- | README.md | 93 |
1 files changed, 71 insertions, 22 deletions
@@ -5,14 +5,15 @@ ## TL;DR This repository is a clone of [llama.cpp](https://github.com/ggerganov/llama.cpp) with the following improvements -* Better implementation of CPU matrix multiplications (`AVX2` and `ARM_NEON`) for `fp16/fp32` and all k-, i-, and legacy `llama.cpp` quants, that leads to a significant improvement in prompt processing (PP) speed. Token generation (TG) also benefits, but to a lesser extent due to TG being memory bound -* Implementation of the [Bitnet b1.58](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) model for the CPU (`AVX2` and `ARM_NEON`) and GPU (`CUDA` and `Metal`) -* Faster CPU inferrence for MoE models +* Better implementation of CPU matrix multiplications (`AVX2` and `ARM_NEON`) for `fp16/fp32` and all k-, i-, and legacy `llama.cpp` quants, that leads to a significant improvement in prompt processing (PP) speed, typically in the range of 2X, but up to 4X for some quantization types. Token generation (TG) also benefits, but to a lesser extent due to TG being memory bound +* Faster CPU inferrence for MoE models with similar performance gains +* Implementation of the [Bitnet b1.58](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) model for the CPU (`AVX2` and `ARM_NEON`) and GPU (`CUDA` and `Metal`). This implementat6ion is much faster than the unmerged `llama.cpp` [PR-8151](https://github.com/ggerganov/llama.cpp/pull/8151) -If you are not already familiar with [llama.cpp](https://github.com/ggerganov/llama.cpp), it is better to start there. For those familiar with `llama.cpp`, everything works the same as `llama.cpp` (or at least the way `llama.cpp` worked when I last synced on June 21). +If you are not already familiar with [llama.cpp](https://github.com/ggerganov/llama.cpp), it is better to start there. For those familiar with `llama.cpp`, everything here works the same as in `llama.cpp` (or at least the way `llama.cpp` worked when I last synced on June 21). -Note that I have published some, but not all, of the code in the respository in a series of [llamafile](https://github.com/Mozilla-Ocho/llamafile) PRs ([394](https://github.com/Mozilla-Ocho/llamafile/pull/394), [405](https://github.com/Mozilla-Ocho/llamafile/pull/405), [428](https://github.com/Mozilla-Ocho/llamafile/pull/428), [435](https://github.com/Mozilla-Ocho/llamafile/pull/435), [453](https://github.com/Mozilla-Ocho/llamafile/pull/453), and [464](https://github.com/Mozilla-Ocho/llamafile/pull/464)) +Note that I have published some, but not all, of the code in this respository in a series of [llamafile](https://github.com/Mozilla-Ocho/llamafile) PRs ([394](https://github.com/Mozilla-Ocho/llamafile/pull/394), [405](https://github.com/Mozilla-Ocho/llamafile/pull/405), [428](https://github.com/Mozilla-Ocho/llamafile/pull/428), [435](https://github.com/Mozilla-Ocho/llamafile/pull/435), [453](https://github.com/Mozilla-Ocho/llamafile/pull/453), and [464](https://github.com/Mozilla-Ocho/llamafile/pull/464)) +The entire implementation is in a single C++ source file (`iqk_mul_mat.cpp`) with just two interface functions `iqk_mul_mat` (`fp16/fp32` and quantized matrix multiplications) and `iqk_mul_mat_moe` (as `iqk_mul_mat` but meant to be used for the FFN part of a MoE model). Under the hood `iqk_mul_mat_moe` uses the same implementation as `iqk_mul_mat`, with the only difference being where results are stored in memory. ## Why? @@ -20,17 +21,7 @@ Mostly out of curiosity: * Justine Tunney's `tinyBLAS`, which she contributed to `llama.cpp` in [PR 6414](https://github.com/ggerganov/llama.cpp/pull/6414), only works for `Q4_0`, `Q8_0` and `fp16/bf16` models. In the surrounding discussion about possibly extending `tinyBLAS` to k- and i-quants, she felt that k-quants are [not ammenable to block-tiling](https://github.com/ggerganov/llama.cpp/pull/6840#issuecomment-2072995387), which is required to improve performance. This statement piqued my curiosity, so here we are. * Bitnet-1.58b has been one of the [most discussed topics](https://github.com/ggerganov/llama.cpp/issues/5761#issuecomment-2198380366) in the `llama.cpp` project, so eventually I decided to see how efficiently one can implement a tertiary model -Curiosity aside, improved CPU performance may be (or may become) important in practice. According to The Register, 70% of AI inferrence [is done on the CPU](https://www.theregister.com/2024/05/30/arm_cortex_x925_ai_cores/?td=rt-3a), at least in the Android world (but I haven't come around to actually comparing performancer on a phone). With ever increasing number of LLM model parameters, and with Meta's 400B model release imminent, the CPU may become the only viable option for people not willing (or not able to) rent/buy uber expensive GPU instances capable of running such models. Granted, one would need a pretty beefy computer to run a 400B model, and inference speed will be sluggish, but at least one will not need to spend the equivalent of a luxury apartmenty in the downtown of the city where I live to buy the GPU system capable of running the model. - -## Bitnet-1.58B - -Two implementations are provided -* `IQ1_BN` - uses 1.625 bits-per-weight (bpw) -* `IQ2_BN` - uses 2.0 bpw - -`IQ2_BN` is faster for PP (CPU and GPU, although the PP performance difference on CUDA is very minor). `IQ1_BN` can arrive at a higher TG performance on the CPU (given enough threads) because of the smaller model size, but it is always slower on the GPU. - -There is the unmerged [PR 8151](https://github.com/ggerganov/llama.cpp/pull/8151) in `llama.cpp` that implements Bitnet-1.58B for the CPU (`AVX` and `ARM_NEON`). The following table compares performance between this repo and `PR-8151` in `llama.cpp`. +Curiosity aside, improved CPU performance may be (or may become) important in practice. According to The Register, 70% of AI inferrence [is done on the CPU of mobile phones](https://www.theregister.com/2024/05/30/arm_cortex_x925_ai_cores/?td=rt-3a), at least in the Android world (but I haven't come around to actually comparing performancer on a phone). With ever increasing number of LLM model parameters, and with Meta's 400B model just released, the CPU may become the only viable option for people not willing (or not able to) rent/buy uber expensive GPU instances capable of running such models. Granted, one would need a pretty beefy computer to run a 400B model, and inference speed will be sluggish, but at least one will not need to spend the equivalent of a luxury apartmenty in the downtown of the city where I live to buy the GPU system capable of running the model. ## Performance comparison to llama.cpp @@ -39,14 +30,14 @@ The results in the following tables are obtained with these parameters: * The `AVX2` CPU is a 16-core Ryzen-7950X * The `ARM_NEON` CPU is M2-Max * `tinyBLAS` is enabled in `llama.cpp` -* `llama.cpp` results are for `build: 081fe431 (3441)`, which was the current `llama.cpp` master branch master branch when I pulled on July 23 2024. -* The project is built without `CUDA` support, no `BLAS`, and Accelerate framework disabled +* `llama.cpp` results are for `build: 081fe431 (3441)`, which was the current `llama.cpp` master branch when I pulled on July 23 2024. +* The projects are built without `CUDA` support, no `BLAS`, and Accelerate framework disabled ### Prompt processing -Here I set the number of threads to be equal to the number of (performance) cores of the CPU, so 16 threads for the Ryzen-7950X and 8 threads for the M2-Max. The following table summarizes the results. To not make the table too long, I have listed only quantized models containing predominantly one quantization type (i.e., excluded the `QX_K - Medium` quants, which are typically a mix of `QX_K` and `Q(X+1)_K`, as well as `IQ2_S` and `IQ3_XS`). +Here I set the number of threads to be equal to the number of (performance) cores of the CPU, so 16 threads for the Ryzen-7950X and 8 threads for the M2-Max. The following table summarizes the results. To not make the table too long, I have listed only quantized models containing predominantly one quantization type (i.e., excluded the `QX_K - Medium/Large` variants, which are typically a mix of `QX_K` and `Q(X+1)_K`, as well as `IQ2_S` and `IQ3_XS`). -The command line to generate the data is +The command line to generate the benchmark data is ``` ./bin/llama-bench -m $model -p 512 -n 0 -t $num_threads -ngl 0 ``` @@ -195,11 +186,69 @@ The command line to generate the data was | | | | 4 | 15.89 ± 0.00 | 24.28 ± 0.29 | 1.528 | | | | | 8 | 26.56 ± 0.36 | 29.87 ± 0.08 | 1.125 | -Here gains are generally lower compared to PP due to TG performance being limited by memory bandwidth. Nevertheless, for some quants/architectures/threads the speedup is quite remarkable (e.g., almost a factor of for `Q5_1` on `AVX2` with 2 threads). +Here gains are generally lower compared to PP due to TG performance being limited by memory bandwidth. Nevertheless, for some quants/architectures/threads the speedup is quite remarkable (e.g., almost a factor of 2 for `Q5_1` on `AVX2` with 2 threads). ## MoE models -There is [PR-6840](https://github.com/ggerganov/llama.cpp/pull/6840) from Justine Tunney in `llama.cpp`, but it has not been merged since April 23, so I'll compare performance to the master branch for Mixtral-8x7B. +There is [PR-6840](https://github.com/ggerganov/llama.cpp/pull/6840) from Justine Tunney in `llama.cpp`, but it has not been merged since April 23, so I'll compare performance to the master branch for Mixtral-8x7B. As Mixtral8x7B quantization is quite a lengthy process, the following shows data only for `Q4_K_S` (a commonly used k-quant, 4 bit), `Q5_0` (a legacy quant, 5 bit), and `IQ4_XXS` (a 3-bit i-quant) + +| model | size | backend | threads | test | t/s (llama.cpp) | t/s (iqk_mul_mat)| Speedup | +| ------------ | ---------: | ---------- | ------: | -------: | ---------------: | ---------------: | ------: | +| 8x7B Q4_K_S | 48.75 GiB | AVX2 | 16 | pp512 | 54.92 ± 0.23 | 102.94 ± 0.37 | 1.874 | +| | | NEON | 8 | pp512 | 23.54 ± 1.56 | 38.32 ± 0.54 | 1.628 | +| | | AVX2 | 4 | tg128 | 7.80 ± 0.07 | 7.83 ± 0.09 | 1.004 | +| | | NEON | 8 | tg128 | 14.95 ± 0.25 | 15.28 ± 0.24 | 2.022 | +| 8x7B IQ3_XXS | 33.07 GiB | AVX2 | 16 | pp512 | 17.58 ± 0.04 | 68.45 ± 0.22 | 3.894 | +| | | NEON | 8 | pp512 | 7.75 ± 0.04 | 34.67 ± 0.40 | 4.474 | +| | | AVX2 | 4 | tg128 | 4.60 ± 0.01 | 5.45 ± 0.09 | 1.185 | +| | | AVX2 | 8 | tg128 | 8.04 ± 0.65 | 9.83 ± 0.06 | 1.223 | +| | | AVX2 | 16 | tg128 | 10.42 ± 0.01 | 10.57 ± 0.01 | 1.014 | +| | | NEON | 8 | tg128 | 6.19 ± 1.16 | 7.27 ± 0.14 | 1.174 | +| 8x7B Q5_0 | 59.11 GiB | AVX2 | 16 | pp512 | 29.06 ± 0.43 | 62.67 ± 0.32 | 2.157 | +| | | NEON | 8 | pp512 | 15.17 ± 0.51 | 27.36 ± 1.03 | 1.804 | +| | | AVX2 | 4 | tg128 | 5.44 ± 0.10 | 6.81 ± 0.06 | 1.252 | +| | | NEON | 8 | tg128 | 12.03 ± 0.77 | 12.41 ± 1.27 | 1.032 | + + +## Bitnet-1.58B + +Two implementations are provided +* `IQ1_BN` - uses 1.625 bits-per-weight (bpw) +* `IQ2_BN` - uses 2.0 bpw + +`IQ2_BN` is faster for PP (CPU and GPU, although the PP performance difference on CUDA is very minor). `IQ1_BN` can arrive at a higher TG performance on the Ryzen-7950X (given enough threads) because of the smaller model size, but it is always slower on the GPU and on the M2-Max CPU. + +There is the unmerged [PR 8151](https://github.com/ggerganov/llama.cpp/pull/8151) in `llama.cpp` that implements Bitnet-1.58B for the CPU (`AVX` and `ARM_NEON`, no GPU implementation). The following table compares performance between this repo and `PR-8151` in `llama.cpp`. + +| model | size | backend | threads | test | t/s (llama.cpp) | t/s (this repo)| Speedup | +| --------------------- | ---------: | ---------- | ------: | -----: | ---------------: | -------------: | ------: | +| bitnet 3B - 1.625 bpw | 729.64 MiB | AVX2 | 16 | pp512 | 120.61 ± 0.48 | 407.06 ± 0.80 | 3.380 | +| | | NEON | 8 | pp512 | 46.64 ± 0.02 | 205.90 ± 0.88 | 4.415 | +| | | CUDA | 8 | pp512 | - | 9655.14 ± 21.01| - | +| | | Metal | 8 | pp512 | - | 697.59 ± 2.12 | - | +| | | AVX2 | 2 | tg128 | 15.79 ± 0.01 | 22.13 ± 0.02 | 1.402 | +| | | AVX2 | 4 | tg128 | 28.64 ± 1.72 | 40.14 ± 0.04 | 1.402 | +| | | AVX2 | 8 | tg128 | 48.91 ± 0.08 | 57.76 ± 2.86 | 1.181 | +| | | AVX2 | 16 | tg128 | 57.73 ± 0.05 | 60.14 ± 0.04 | 1.042 | +| | | NEON | 2 | tg128 | 11.43 ± 0.04 | 16.87 ± 0.02 | 1.476 | +| | | NEON | 4 | tg128 | 21.11 ± 0.05 | 30.66 ± 0.11 | 1.452 | +| | | NEON | 8 | tg128 | 37.36 ± 0.07 | 55.21 ± 0.16 | 1.478 | +| | | CUDA | 8 | tg128 | - | 229.21 ± 0.89 | - | +| | | Metal | 8 | tg128 | - | 69.33 ± 0.07 | - | +| bitnet 3B - 2.000 bpw | 873.65 MiB | AVX2 | 16 | pp512 | 151.39 ± 0.35 | 512.79 ± 2.58 | 3.387 | +| | | NEON | 8 | pp512 | 46.54 ± 0.03 | 242.05 ± 0.34 | 5.201 | +| | | CUDA | 8 | pp512 | - | 9810.91 ± 25.00| - | +| | | Metal | 8 | pp512 | - | 722.66 ± 0.47 | - | +| | | AVX2 | 2 | tg128 | 18.93 ± 0.02 | 37.42 ± 0.07 | 1.978 | +| | | AVX2 | 4 | tg128 | 34.54 ± 0.06 | 53.25 ± 0.02 | 1.542 | +| | | AVX2 | 8 | tg128 | 52.97 ± 0.07 | 52.06 ± 0.08 | 0.983 | +| | | AVX2 | 16 | tg128 | 51.84 ± 0.25 | 52.98 ± 0.03 | 1.022 | +| | | NEON | 2 | tg128 | 11.40 ± 0.02 | 32.01 ± 0.27 | 2.808 | +| | | NEON | 4 | tg128 | 20.99 ± 0.00 | 56.45 ± 0.11 | 2.689 | +| | | NEON | 8 | tg128 | 37.28 ± 0.08 | 89.77 ± 0.70 | 2.408 | +| | | CUDA | 8 | tg128 | - | 241.34 ± 0.27 | - | +| | | Metal | 8 | tg128 | - | 95.22 ± 0.55 | - | + ## To tile or not to tile |