path: root/ggml-cuda
Age  Commit message  Author
2024-07-27  Merge mainline llama.cpp (#3)  (Kawrakow)
* Merging mainline - WIP
* Merging mainline - WIP. AVX2 and CUDA appear to work. CUDA performance seems slightly (~1-2%) lower, as is so often the case with llama.cpp/ggml after some "improvements" have been made.
* Merging mainline - fix Metal
* Remove check
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2024-07-24  Add copyright notices  (Iwan Kawrakow)
Only on the files where I have contributed in a significant way, or the files I wrote myself.
2024-07-17  iq1bn(no lookup): better version  (Iwan Kawrakow)
We have 4 groups of 16 in a block of 64 quants. For each group of 16 we have 3 groups of 5, each using 8 bits. The remaining 16th quants of the 4 groups of 16 are encoded with 8 bits using the same encoding as the groups of 5. The only kernel where we have complications is the CUDA dequantize kernel (because we are dequantizing 8 quants there, and we have different encodings for the 1st and 2nd group of 8 in a group of 16). This achieves better performance on all tested platforms than any previous 1.625 bpw attempt. We have:

| model            |       size |  params | backend | threads |  test |             t/s |
| ---------------- | ---------: | ------: | ------- | ------: | ----: | --------------: |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | CUDA    |       8 | pp512 | 9613.02 ± 24.54 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | CUDA    |       8 | tg128 |   229.85 ± 0.33 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | AVX2    |      16 | pp512 |   322.59 ± 1.00 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | AVX2    |      16 | tg128 |    59.79 ± 0.03 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | AVX2    |       8 | tg128 |    57.62 ± 0.21 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | AVX2    |       4 | tg128 |    33.66 ± 0.29 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | AVX2    |       2 | tg128 |    18.30 ± 0.01 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | Metal   |       8 | pp512 |   698.13 ± 0.21 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | Metal   |       8 | tg128 |    68.88 ± 0.24 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | NEON    |       8 | pp512 |   196.80 ± 0.50 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | NEON    |       8 | tg128 |    51.58 ± 0.41 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | NEON    |       4 | tg128 |    30.80 ± 0.03 |
| 1.625 bpw Bitnet | 729.64 MiB |  3.32 B | NEON    |       2 | tg128 |    16.89 ± 0.01 |

It is still slower than 2 bpw Bitnet, but the difference now is not as dramatic.
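To make the 8-bits-per-5-quants packing concrete (3^5 = 243 fits in one byte), here is a minimal sketch of unpacking one such code, assuming plain base-3 digit packing; the helper name and the division-based decode are illustrative, not the repo's code (the actual kernels avoid per-quant divisions).

```c
#include <stdint.h>

// Unpack one 8-bit code into 5 ternary quants.
// Assumes base-3 digit packing; maps digits {0,1,2} to values {-1,0,+1}.
static inline void iq1bn_unpack5(uint8_t code, int8_t q[5]) {
    for (int i = 0; i < 5; ++i) {
        q[i] = (int8_t)(code % 3) - 1;
        code /= 3;
    }
}
```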
2024-07-15  iq1bn(no lookup): CUDA  (Iwan Kawrakow)
Not good. We only get ~160 t/s.
2024-06-25  bitnet: remove iq1_bn lookup table storing +/- signs  (Iwan Kawrakow)
The AVX2 implementation was the only one left using it, so I decided to see if we can get a performant implementation using the 0,1,2 lookup table. Turns out we can, and it is even slightly faster than the sign based table. We now get PP-512 = 275 t/s and TG-128 = 57.7 t/s with 16 threads on the Ryzen-7950X. With only one lookup table left for iq1_bn, I renamed it to iq1bn_grid_u16.
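For intuition, a sketch of what a 0,1,2-style grid lookup can look like, assuming each uint16_t entry packs 8 ternary digits at 2 bits apiece; this layout and the helper name are assumptions for illustration, not necessarily what iq1bn_grid_u16 actually stores.

```c
#include <stdint.h>

// Hypothetical layout: 8 quants per entry, 2 bits each, digits {0,1,2}.
static inline int8_t iq1bn_grid_quant(uint16_t entry, int j) {
    return (int8_t)((entry >> (2*j)) & 3) - 1;  // {0,1,2} -> {-1,0,+1}
}
```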
2024-06-25  Bitnet: trying an alternative iq1_bn grid  (Iwan Kawrakow)
Faster on CUDA. The scalar version is faster too. The issue with CUDA is that now I see wild performance fluctuations. Running llama-bench I can get 220 t/s for TG-128 one time, and 190 t/s another time, with uncertainties of 1-2 t/s. Same for PP: results jump back and forth between ~9500 t/s and ~8900 t/s. So, basically no reliable measurement at this point, but for sure faster than the previous version, which was at around 170-180 t/s.
2024-06-22  bitnet(scale in a separate tensor): mul -> scale on CUDA  (Iwan Kawrakow)
On CUDA we do not have access to the tensor data until we hit the kernel; hence this hack. In any case, iq2_bn goes back up to 228 t/s, which is close to the 234 t/s we have without the extra scale operation. PP is 9400 t/s, down from 9600 t/s, but better than the 9200 t/s we get without making the mul -> scale replacement.
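The win from mul -> scale is that the scalar rides along as a kernel argument instead of being loaded from a second tensor per element. A minimal sketch of such an elementwise scale kernel (illustrative, not the repo's exact code):

```cuda
__global__ void scale_f32(const float * x, float * dst, const float s, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = s * x[i];  // one load, one store; no second tensor to read
}
```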
2024-06-22  bitnet(scale in a separate tensor): CUDA  (Iwan Kawrakow)
2024-06-22  Bitnet(1.75 bpw): higher precision fp8 scale  (Iwan Kawrakow)
Use 3 bits for the exponent and 5 bits for the mantissa. This makes PPL the same as fp16 (but the previous version, with 4 bits each for the exponent and the mantissa, was good enough for any practical purposes).
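As a sketch of what an e3m5 decode looks like: the commit only specifies the 3/5 bit split, so the unsigned format, the implicit leading 1, the exponent bias of 4, and the function name below are all assumptions for illustration.

```c
#include <math.h>
#include <stdint.h>

// Decode an 8-bit e3m5 scale: 3 exponent bits, 5 mantissa bits, no sign
// bit (block scales are non-negative). The bias of 4 is assumed.
static inline float fp8_e3m5_to_float(uint8_t v) {
    const int   e = v >> 5;                 // exponent in [0, 7]
    const float m = (float)(v & 31)/32.0f;  // mantissa fraction in [0, 1)
    return (1.0f + m) * ldexpf(1.0f, e - 4);
}
```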
2024-06-22  Bitnet(1.75 bpw): slightly faster CUDA dot product  (Iwan Kawrakow)
We get 205 t/s, so ~13% slower than 2 bit.
2024-06-22  Bitnet(2.25 bpw): CUDA  (Iwan Kawrakow)
We get PP-512 = 9600 t/s, TG-128 = 234 t/s (but we need to use 8 CPU threads, else results are lower, so clearly there is something being computed on the CPU). PP-512 is very close to PP-512(fp16) = 9800 t/s
2024-06-22  bitnet: CUDA, scalar, AVX2  (Iwan Kawrakow)
2024-06-20  CUDA: stream-k decomposition for MMQ (#8018)  (Johannes Gäßler)
* CUDA: stream-k decomposition for MMQ
* fix undefined memory reads for small matrices
2024-06-17  Add support for sqrt on CUDA (#7953)  (Calvin Laurenson)
* cuda sqrt support
* enable cuda in pca
* fix comments in pca
* add test
* add sqrt to ggml_backend_cuda_supports_op
* fix test
* new line
* Use F32 sqrtf instead of F64 sqrt
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
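The op itself is a one-liner on the device side. A minimal sketch of an F32 sqrt kernel in the spirit of the "F32 sqrtf instead of F64 sqrt" item (illustrative names, not the repo's exact code):

```cuda
__global__ void sqrt_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = sqrtf(x[i]);  // F32 sqrtf, not the F64 sqrt overload
}
```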
2024-06-16  cuda : fix bounds check for src0 rows in MMVQ kernel (whisper/2231)  (Georgi Gerganov)
* cuda : fix bounds check for src0 rows in MMVQ kernel
* Update ggml-cuda/mmvq.cu
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2024-06-14  CUDA: faster q2_K, q3_K MMQ + int8 tensor cores (#7921)  (Johannes Gäßler)
* CUDA: faster q2_K, q3_K MMQ + int8 tensor cores
* try CI fix
* try CI fix
* try CI fix
* fix data race
* revert q2_K precision related changes
2024-06-12  CUDA: fix broken oob check for FA vec f32 kernel (#7904)  (Johannes Gäßler)
2024-06-12  tests : add non-cont unary tests (#7857)  (Georgi Gerganov)
* tests : add non-cont unary tests
* ggml : update unary asserts and "supports_op" ggml-ci
2024-06-11  CUDA: int8 tensor cores for MMQ (q4_K, q5_K, q6_K) (#7860)  (Johannes Gäßler)
2024-06-10  CUDA: use tensor cores for MMQ (#7676)  (Johannes Gäßler)
* CUDA: int8 tensor cores for MMQ (legacy quants)
* fix out-of-bounds writes
* __builtin_assume -> GGML_CUDA_ASSUME
* fix writeback returning too early
2024-06-09  CUDA: revise q8_1 data layout for mul_mat_q (#7824)  (Johannes Gäßler)
2024-06-05  CUDA: refactor mmq, dmmv, mmvq (#7716)  (Johannes Gäßler)
* CUDA: refactor mmq, dmmv, mmvq
* fix out-of-bounds write
* struct for qk, qr, qi
* fix cmake build
* mmq_type_traits
2024-06-05  ggml : refactor rope norm/neox (#7634)  (Georgi Gerganov)
* ggml : unify rope norm/neox (CPU)
* ggml : fix compile warning
* ggml : remove GLM rope mode ggml-ci
* metal : better rope implementation ggml-ci
* cuda : better rope implementation ggml-ci
* naming : n_orig_ctx -> n_ctx_orig ggml-ci
* dev : add reminders to update backends ggml-ci
* vulkan : fix ggml_rope_ext() usage
* cuda : fix array size + indents ggml-ci
2024-06-01  Fix FlashAttention debug test, FP32 assert (#7684)  (Johannes Gäßler)
2024-06-01  CUDA: fix Pascal FA, deq. KV to FP16 for batch > 8 (#7681)  (Johannes Gäßler)
2024-06-01  CUDA: quantized KV support for FA vec (#7527)  (Johannes Gäßler)
* CUDA: quantized KV support for FA vec
* try CI fix
* fix commented-out kernel variants
* add q8_0 q4_0 tests
* fix nwarps > batch size
* split fattn compile via extern templates
* fix flake8
* fix metal tests
* fix cmake
* make generate_cu_files.py executable
* add autogenerated .cu files
* fix AMD
* error if type_v != FP16 and not flash_attn
* remove obsolete code
2024-05-29  ggml : fix YARN + add tests + add asserts (#7617)  (Georgi Gerganov)
* tests : add rope tests ggml-ci
* ggml : fixes (hopefully) ggml-ci
* tests : add non-cont tests ggml-ci
* cuda : add asserts for rope/norm + fix DS2 ggml-ci
* ggml : assert contiguousness
* tests : reduce RoPE tests ggml-ci
2024-05-29  cuda : non-cont concat support (#7610)  (Georgi Gerganov)
* tests : add non-cont concat tests
* cuda : non-cont concat support ggml-ci
2024-05-28  ggml : generalize GGML_OP_CONCAT (#7563)  (Georgi Gerganov)
* ggml : generalize GGML_OP_CONCAT (WIP) ggml-ci
* tests : add dim != 2 tests
* metal : generalize concat kernel
* tests : naming
* cuda : generalize concat kernel ggml-ci
* sycl : add warning and assert
* ggml : fix op params handling
* metal : bugfix kernel ggml-ci
* ggml : reimplement CPU and Metal
* cuda : add asserts ggml-ci
* ggml : fix ptrs ggml-ci
2024-05-28  update HIP_UMA #7399 (#7414)  (Djip007)
* update HIP_UMA #7399: add use of hipMemAdviseSetCoarseGrain when LLAMA_HIP_UMA is enabled. Gets 2x on prompt eval and 1.5x on token gen with ROCm 6.0 on a Ryzen 7940HX iGPU (780M/gfx1103).
* simplify code, more consistent style
---------
Co-authored-by: slaren <slarengh@gmail.com>
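A minimal sketch of the UMA path this enables, assuming a hipMallocManaged allocation; the helper name is illustrative and error handling is reduced to a best-effort advise:

```cpp
#include <hip/hip_runtime.h>

// Allocate managed (UMA) memory and mark it coarse-grain so the iGPU can
// cache it; the advise is best effort and safely ignored where unsupported.
static void * alloc_uma_coarse(size_t size, int device) {
    void * ptr = nullptr;
    if (hipMallocManaged(&ptr, size) != hipSuccess) {
        return nullptr;
    }
    (void) hipMemAdvise(ptr, size, hipMemAdviseSetCoarseGrain, device);
    return ptr;
}
```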
2024-05-23  ggml : drop support for QK_K=64 (#7473)  (Georgi Gerganov)
* ggml : drop support for QK_K=64 ggml-ci
* opencl : restore QK_K=256 define
2024-05-23  CUDA: fix FA out-of-bounds reads (#7479)  (Johannes Gäßler)
2024-05-22  CUDA: fix FA out-of-bounds writes (#7465)  (Johannes Gäßler)
2024-05-22  cuda : fix compile warning (#7454)  (Georgi Gerganov)
2024-05-22  CUDA: remove incorrect precision check (#7454)  (Johannes Gäßler)
2024-05-22  cuda : fix rope + add tests (#7452)  (Georgi Gerganov)
* cuda : fix rope pos data ggml-ci
* ggml : drop mode & 1 == 1 support for ggml_rope ggml-ci
* ggml : support freq_factors for f16 rope (CPU) ggml-ci
* tests : add rope tests using frequency factors ggml-ci
2024-05-21  llama : add phi3 128K model support (#7225)  (liuwei-git)
* add phi3 128k support in convert-hf-to-gguf
* add phi3 128k support in cuda
* address build warnings on llama.cpp
* adjust index value in cuda long rope freq factors
* add long rope support in ggml cpu backend
* make freq factors only depend on ctx size
* remove unused rope scaling type 'su' from gguf converter
* fix lint warnings on convert-hf-to-gguf.py
* set to the short freq factor when context size is smaller than the trained context size
* add one line of comments
* metal : support rope freq_factors
* ggml : update ggml_rope_ext API to support freq. factors
* backends : add dev messages to support rope freq. factors
* minor : style
* tests : update to use new rope API
* backends : fix pragma semicolons
* minor : cleanup
* llama : move rope factors from KV header to tensors
* llama : remove tmp assert
* cuda : fix compile warning
* convert : read/write n_head_kv
* llama : fix uninitialized tensors
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
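For reference, a hedged sketch of where the frequency factors enter RoPE: the per-dimension base angle is divided by the matching factor when one is provided. Names are illustrative, not the repo's exact code.

```c
#include <math.h>

// Angle for rotation pair i at position pos; freq_factors may be NULL.
static inline float rope_theta(float pos, int i, int n_dims,
                               float freq_base, const float * freq_factors) {
    const float theta = pos * powf(freq_base, -2.0f*i/n_dims);
    return freq_factors ? theta / freq_factors[i] : theta;
}
```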
2024-05-21  CUDA: fix unused warning in mmq.cu (#7442)  (Johannes Gäßler)
2024-05-21  CUDA: deduplicate mmq code (#7397)  (Johannes Gäßler)
2024-05-18  CUDA: deduplicate FlashAttention code (#7352)  (Johannes Gäßler)
2024-05-18  cuda : add half2 __shfl_xor() for ROCm 5.5 (#7263)  (Engininja2)
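A hedged sketch of the gap being filled: ROCm 5.5 lacks a half2 overload of __shfl_xor, so the 32-bit pattern can be routed through the int shuffle that does exist. Illustrative only; the repo's guards and names may differ.

```cpp
#include <string.h>

static __device__ __forceinline__ half2 shfl_xor_h2(half2 v, int mask, int width) {
    int tmp;
    memcpy(&tmp, &v, sizeof(tmp));       // reinterpret half2 bits as int
    tmp = __shfl_xor(tmp, mask, width);  // the int overload exists on ROCm
    memcpy(&v, &tmp, sizeof(tmp));
    return v;
}
```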
2024-05-17  CUDA: faster large batch FA without tensor cores (#7314)  (Johannes Gäßler)
2024-05-15  ggml : add `ggml_upscale_ext` (ggml/814)  (John Balis)
* initial commit with CPU implementation of upscale to shape and test, cuda implementation next
* experimental commit to see if dst shape is correct
* test version
* test
* removed unnecessary params
* refactor
* fixed tests
* ggml : metal impl + cleanup + sycl dev warnings
* patched ggml_upscale cuda op to handle non-contiguous tensors, added test for non-contiguous behavior
* metal : fix upscale op to support nb00 + style
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-05-12  CUDA: add FP32 FlashAttention vector kernel (#7188)  (Johannes Gäßler)
* CUDA: add FP32 FlashAttention vector kernel
* fixup! CUDA: add FP32 FlashAttention vector kernel
* fixup! fixup! CUDA: add FP32 FlashAttention vector kernel
* fixup! fixup! fixup! CUDA: add FP32 FlashAttention vector kernel
2024-05-11  feat: implemented sigmoid function (ggml/806)  (Justina Cho)
* added sigmoid function
* implemented metal kernel for sigmoid
* implemented cuda kernel for sigmoid
* added sigmoid unary op and incremented count
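The CUDA side of such a unary op is a short elementwise kernel. A sketch of the logistic sigmoid over an F32 tensor (illustrative names, not the repo's exact code):

```cuda
__global__ void sigmoid_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = 1.0f / (1.0f + expf(-x[i]));
}
```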
2024-05-11  ggml : full ALiBi support (#7192)  (Georgi Gerganov)
* ggml : full ALiBi support
* ggml : update ggml_soft_max_ext() CUDA, SYCL
* ggml : ggml_flash_attn_ext() support ALiBi (CPU)
* ggml : ggml_flash_attn_ext() support ALiBi (Metal)
* ggml : fix warning
* ggml : ggml_flash_attn_ext() support ALiBi (CUDA) ggml-ci
* ggml : fix assert message
* vulkan : add dev notes
* ggml : require mask when using ALiBi ggml-ci
* convert : fix convert for refact models
2024-05-09  CUDA: generalize FP16 fattn vec kernel (#7061)  (Johannes Gäßler)
* CUDA: generalize FP16 fattn vec kernel
* disable unsupported head sizes for AMD in test
* try AMD fix
* fix batch size 2-8
* partially revert changes
2024-05-08  Introduction of CUDA Graphs to LLama.cpp (#6766)  (agray3)
* DRAFT: Introduction of CUDA Graphs to LLama.cpp
* Fix issues raised in comments
* Tidied to now only use CUDA runtime (not mixed with driver calls)
* disable for multi-gpu and batch size > 1
* Disable CUDA graphs for old GPU arch and with env var
* added missing CUDA_CHECKs
* Addressed comments
* further addressed comments
* limit to GGML_ALLOW_CUDA_GRAPHS defined in llama.cpp cmake
* Added more comprehensive graph node checking
* With mechanism to fall back if graph capture fails
* Revert "With mechanism to fall back if graph capture fails" (reverts commit eb9f15fb6fcb81384f732c4601a5b25c016a5143)
* Fall back if graph capture fails and address other comments
- renamed GGML_ALLOW_CUDA_GRAPHS to GGML_CUDA_USE_GRAPHS
- renamed the env variable to disable CUDA graphs to GGML_CUDA_DISABLE_GRAPHS
- updated Makefile build to enable CUDA graphs
- removed graph capture failure checking in ggml_cuda_error: using a global variable to track this is not thread safe, but I am also not satisfied with checking an error by string; if this is necessary to work around some issues with graph capture with e.g. cuBLAS, we can pass the ggml_backend_cuda_context to the error checking macro and store the result in the context
- fixed several resource leaks
- fixed issue with zero node graphs
- changed fixed size arrays to vectors
- removed the count of the number of evaluations before starting capture, and instead changed the capture mode to relaxed
- removed the check for multiple devices so that it is still possible to use a single device; instead check for split buffers to disable CUDA graphs with -sm row
- changed the op for checking batch size to GGML_OP_ADD, which should be more reliable than GGML_OP_SOFT_MAX
- code style fixes
- things to look into: VRAM usage of the cudaGraphExec_t (if it is significant we may need to make it optional); possibility of using cudaStreamBeginCaptureToGraph to keep track of which ggml graph nodes correspond to which CUDA graph nodes
* fix build without cuda graphs
* remove outdated comment
* replace minimum cc value with a constant
---------
Co-authored-by: slaren <slarengh@gmail.com>
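A minimal sketch of the capture-and-replay pattern described above, using the CUDA runtime API only and the relaxed capture mode mentioned in the notes; error handling is elided and the helper name is illustrative.

```cuda
#include <cuda_runtime.h>

static cudaGraphExec_t capture_graph(cudaStream_t stream) {
    cudaGraph_t     graph    = nullptr;
    cudaGraphExec_t instance = nullptr;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
    // ... enqueue the kernels of one decode step on `stream` here ...
    cudaStreamEndCapture(stream, &graph);
    // 5-argument form (CUDA 11); CUDA 12 also offers a flags overload
    cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);
    cudaGraphDestroy(graph);
    return instance;  // later steps: cudaGraphLaunch(instance, stream);
}
```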
2024-05-01  CUDA: CUDART < 11.7 workaround for __hmax, __hmax2 (#7019)  (Johannes Gäßler)
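A hedged sketch of the kind of fallback this implies: __hmax only exists for CUDART >= 11.7, so older toolkits compare through float. Illustrative only; the repo's exact guards and names may differ.

```cuda
#include <cuda_fp16.h>

static __device__ __forceinline__ half hmax_compat(const half a, const half b) {
#if CUDART_VERSION >= 11070
    return __hmax(a, b);
#else
    return __half2float(a) > __half2float(b) ? a : b;  // manual fallback
#endif
}
```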
2024-04-30  ggml : add Flash Attention (#5021)  (Georgi Gerganov)
* ggml : add ggml_flash_attn_ext API
* ggml : fix GQA support in ggml_flash_attn_ext
* ggml : online attention (CPU)
* metal : initial implementation
* metal : f16 precision
* metal : reduce branches
* metal : specialize for head size
* wip : 8 rows per simd group
* wip : 4 rows per simd group
* wip : template for rows per warp
* metal : parallelize across KV size
* metal : parallel reduce across heads
* metal : efficient flash_attn_f16 implementation
* metal : avoid redundant loads of the attention
* metal : scale and mask in matrix form
* metal : fix comment
* llama : avoid ggml_cast, use F32 query
* metal : add parallel reduce version (disabled)
* metal : move output into local memory + optimize
  - the result from each simdgroup now stays in the registers
  - significantly reduced SRAM usage
  - more efficient skipping of -INF blocks
  - avoid simdgroup barrier in hot loop
  - add comments
* metal : add tests, fix scaling, support C > 32
* metal : improve precision
* ggml : fix f16 mad
* metal : minor
* metal : support Q > 8
* tests : add ATTN tests
* metal : disable buffer allocation logs
* tests : more
* metal : faster inner loop for C == 32
* metal : fix array initialization
* tests : ifdef
* ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
* ggml : fix ggml_soft_max mask requirement
* cuda : fix soft_max to use correct mask size
* cuda : add flash_attn kernel (wip)
* metal : optimize softmax for C > 32
* metal : optimize softmax
* tests : minor fix
* cuda : avoid zeroing fragments
* tests : update dims
* cuda : fix __hisinf() result check
* cuda : avoid warp_reduce for smax
* cuda : use int instead of int64_t (noticeably improves performance, thanks to Johannes)
* cuda : make loops use the same loop values (thanks Johannes again for the tip)
* cuda : unroll some of the loops
* cuda : avoid __hisinf branches
* cuda : use half2 in softmax
* cuda : switch to 1 warp for bs > 16
* cuda : speed-up reduce part of the kernel
* cuda : unroll Q*K^T loop
* cuda : fix -INF block check
* cuda : simplify softmax
* cuda : fix matrix names
* cuda : minor
* llama : adapt to F16 KQ_pos
* llama : adapt new models to F16 KQ_mask
* ggml : fix F16 store (ARM NEON)
* llama : fix type of KQ_mask and KQ_pos
* ggml : fix CPU soft_max
* tests : add hs=256
* cuda : fix build
* metal : improve perf via smaller int registers
* cuda : adapt soft_max to F16 mask and pos
* CUDA: faster FlashAttention, kernel for bs == 1
* 16 cols for Phi-2
* no vec for hs, no hs==256 ncols==32 for Volta
* adjust kernel selection logic
* 4 warps, 256 stride for all D
* no ncols == 64
* Multiple parallel blocks for batch size 1
* fix compile warnings
* fix excessive KQ_b loads
* fix cmake build
* fix KV cache padding, NaN from INFINITY (#6438)
* llama : flash_attn cparam + fix defrag
* server: support flash_attn param
* server: bench: enable flash_attn param
* CUDA: refactor host code, dyn. par. blocks
* fix flash_attn_vec_f16 race condition
* flush softmax exp below threshold to 0
* store temp KQ in registers
* Calculate KQ as FP32 if KQV has GGML_PREC_F32
* Add __hgt2_mask implementation for CUDA 11
* fix KQ FP32 precision for parallel_blocks > 1
* llama-bench : add -fa,--flash-attn arg
* metal : add BS=1 kernel for flash attention (#6508)
* metal : add BS=1 kernel for flash attention (wip)
* metal : support more than 1 warps
* metal : opts
* metal : opt
* metal : switch to parallel reduce
* metal : reduce registers
* metal : simplify
* metal : initial FA vec kernel
* metal : use F32 attention accumulators
* batched-bench : add fattn arg
* llama : simplify llama_build_kv_store ggml-ci
* llama : adapt build_olmo to changes
* ggml : fix arm fp16 store on windows
* metal : clean-up
* metal : clean-up kernel code
* metal : minor
* tests : remove benchmarks ggml-ci
* ggml : fix avx512 const correctness ggml-ci
* ggml : fix soft_max with bias on CPU ggml-ci
* common : print --flash-attn in help
* ggml : fix num dimensions in ggml_flash_attn_ext
* llama : force disable flash attention for incompatible models
* ggml : ggml_soft_max support F16/F32 mask/pos ggml-ci
* cuda : uint -> uint32_t
* cuda : "constexpr dim3" -> "const dim3" ggml-ci
* cuda : try to fix __hgt2_mask ggml-ci
* ggml : add TODO's for F16/F32 mask/pos support in other backends
* llama : replace bool need_kq_pos with use_alibi
* llama : prep ALiBi support for BERT models ggml-ci
* llama : fix n_batch requirements ggml-ci
* cont
* server : add help for --flash-attn arg
* llama : disable FA for AMD
* tests : remove TMP_ATTN_BENCH ggml-ci
* llama : support save/load state with FA enabled ggml-ci
* ci : add CUDA save-load-state tests ggml-ci
* llama : llama_kv_cache_clear zeroes data + fix save-load seq ggml-ci
* llama : fix copy-paste errors, add TODO
* llama : disallow incompatible states
* llama : update llama_state_get_size after v_trans field
* metal : remove tmp log
* llama : add static reminder for llama_state_get_size
* metal : fix max nsg ggml-ci
* ci : fix arg order ggml-ci
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com>
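Many of the items above ("online attention", "flush softmax exp below threshold", "use F32 attention accumulators") revolve around the online-softmax accumulation that flash attention kernels share. A hedged scalar sketch of that inner step, for clarity only, not the repo's kernels: the running max and sum are rescaled as each new KQ score arrives, so the softmax never needs a second pass.

```c
#include <math.h>

// One online-softmax update: kq is the new score, v its value row of
// length d; m/s/acc carry the running max, sum, and output accumulator.
// The attention output is acc[i] / s once all scores have been folded in.
static void fa_accumulate(float kq, const float * v, int d,
                          float * m, float * s, float * acc) {
    const float m_new = kq > *m ? kq : *m;
    const float scale = expf(*m - m_new);  // rescale previous accumulation
    const float p     = expf(kq - m_new);  // weight of the new value row
    for (int i = 0; i < d; ++i) {
        acc[i] = acc[i]*scale + p*v[i];
    }
    *s = *s*scale + p;
    *m = m_new;
}
```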