path: root/llama.cpp
Age | Commit message | Author
2024-03-08 | llama : support Mamba Selective State Space Models (#5328) | compilade
* mamba : begin working on support for Mamba SSM * mamba : begin figuring out how to (ab)use the kv cache for Mamba * mamba : recurrent inference almost works, but incoherent * mamba : recurrent inference WORKS!!! * convert : optionally use d_conv and d_state from config.json for Mamba * mamba : refactor recurrent conv, resulting in 20% perf increase It's still slower than I'd like, but I did not really optimize `ggml_exp` yet. I also refactored `ggml_exp` to work with tensors with more than 2 dimensions. * ggml : parallelize ggml_exp This results in 8% faster token generation for Mamba-130M. * mamba : simplify the conv step with a self-overlapping view Turns out the conv_state can be made smaller by one column. Note that this breaks existing GGUFs of Mamba, because the key_value_length field is tied to the conv_state size. Convolution with a self-overlapping view is cool! And it's much simpler than what I initially thought would be necessary to make the convolution step work with more than 1 token at a time. Next step is to make the SSM step work on batches of tokens too, and thus I need to figure out a way to make a parallel selective scan which will keep the ssm_state small and won't make it bigger by a factor of (n_layer * batch_size). * llama : fix Mamba KV self size wrongly displaying as f16 instead of f32 Relatedly, I also tried to see if other types than f32 worked for the states, but they don't, because of the operators used. It's probably better anyway to keep lots of precision there, since the states are small anyway. * mamba : fix self-overlapping view depth stride * mamba : handle batches of more than 1 token This means running Mamba no longer crashes when using the default settings! And probably also slightly faster prompt processing. Both batched and non-batched processing yield the same output. Previously, the state was not cleared when starting a sequence. Next step is to make the KV cache API work as expected for Mamba models. 
* ggml: add ggml_ssm_scan to help with parallel selective scan If the selective scan was implemented without a custom operator, there would be waaay too many nodes in the graph. For example, for Mamba-130M, with a batch size of 512 (the default), a naive selective scan could add at least 24*512=12288 nodes, which is more than LLAMA_MAX_NODES (8192), and that's only for the smallest Mamba model. So it's much cleaner with a custom operator. Not sure about the name, though. * ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation This will help with performance on CPU if ggml_vec_mul_f32 and ggml_vec_add_f32 are ever optimized with SIMD. * mamba : very basic quantization support Mostly works, but there is currently no difference between the variants of a k-quant (e.g. Q4_K_S and Q4_K_M are the same). Most of the SSM-specific weights can be kept in f32 without affecting the size that much, since they are relatively small. (the linear projection weights are responsible for most of Mamba's size) Too much quantization seems to make the state degrade quite fast, and the model begins to output gibberish. It seems to affect bigger models to a lesser extent than small models, but I'm not sure by how much. Experimentation will be needed to figure out which weights are more important for the _M (and _L?) variants of k-quants for Mamba. * convert : fix wrong name for layer norm weight of official Mamba models I was using Q-bert/Mamba-* models before, which have a slightly different naming scheme for the weights. (they start with "model.layers" instead of "backbone.layers") * mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator This increases performance on CPU by around 30% for prompt processing, and by around 20% for text generation. However, it also makes the ggml_exp and ggml_soft_plus operators unused. Whether or not they should be kept will be decided later.
* convert : for Mamba, also consider the "MambaLMHeadModel" arch name It's the name of the class of the official implementation, though they don't use it (yet) in the "architectures" field of config.json * mamba : fix vocab size problems with official models The perplexity was waaaay too high for models with a non-round vocab size. Not sure why, but it needed to be fixed in the metadata. Note that this breaks existing GGUF-converted Mamba models, but **only if** the vocab size was not already rounded. * ggml : remove ggml_exp and ggml_soft_plus They did not exist anyway outside of this branch, and since ggml_ssm_scan fused operations together, they are unused. It's always possible to bring them back if needed. * mamba : remove some useless comments No code change. * convert : fix flake8 linter errors * mamba : apply suggestions from code review * mamba : remove unnecessary branch for row-wise ssm_state and C multiplication It was previously done to avoid permuting when only one token is processed at a time (like when generating text), but permuting is cheap, and dynamically changing the compute graph is not future-proof. * ggml : in ggml_ssm_scan, use more appropriate asserts * ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32 * mamba : multiple sequences, but one at a time This is a step towards making this Mamba implementation usable with the server example (the way the system prompt is kept when clearing the client slots will need to be changed before this can work, though). The KV cache size for this kind of model is tied to the maximum number of sequences kept at any single time. For now, this number is obtained from n_parallel (plus one, to have an extra sequence to dedicate to the system prompt), but there might be a better way to do this which won't also make the main example use 2 cells even if only 1 is really used.
(for this specific case, --parallel 0 helps) Simultaneous sequence processing will probably require changes to ggml_ssm_scan, and possibly a new operator for the conv step. * mamba : support llama_kv_cache_seq_cp This (mis)uses the logic around K shifts, because tokens in a state can't be shifted anyway, and because inp_K_shift has the right shape and type. Using ggml_get_rows is a nice way to do copies, but copy chains can't work. Fortunately, copy chains don't really seem to be used in the examples. Each KV cell is dedicated to the sequence ID corresponding to its own index. * mamba : use a state mask It's cleaner than the previous heuristic of checking for the pos of the first token in the batch. inp_KQ_mask could not be re-used for this, because it has the wrong shape and because it seems more suited to the next step of simultaneous sequence processing (helping with the problem of remembering which token belongs to which sequence(s)/state(s)). * llama : replace the usage of n_ctx with kv_self.size in many places * mamba : use n_tokens directly instead of n_tok * mamba : in comments, properly refer to KV cells instead of slots * mamba : reduce memory usage of ggml_ssm_scan From 290.37 MiB to 140.68 MiB of CPU compute buffer size with Mamba 3B with a batch size of 512. The result tensor of ggml_ssm_scan was previously a big part of the CPU compute buffer size. To make it smaller, it does not contain the intermediate ssm states anymore. Both y and the last ssm state are combined in the result tensor, because it seems only a single tensor can be returned by an operator with the way the graph is built. * mamba : simultaneous sequence processing A batch can now contain tokens from multiple sequences. This is necessary for at least the parallel example, the server example, and the HellaSwag test in the perplexity example. However, for this to be useful, uses of llama_kv_cache_seq_rm/cp will need to be changed to work on whole sequences. 
* ggml : add ggml_ssm_conv as a new operator for the conv step of Mamba This operator makes it possible to use and update the correct states for each token of the batch in the same way as ggml_ssm_scan. Other solutions which use existing operators would need loops which would add too many nodes to the graph (at least the ones I thought of). Using this operator further reduces the size of the CPU compute buffer from 140.68 MiB to 103.20 MiB with Mamba 3B with a batch size of 512. And (at least on CPU), it's a bit faster than before. Note that "ggml_ssm_conv" is probably not the most appropriate name, and it could be changed if a better one is found. * llama : add inp_s_seq as a new input tensor The most convenient implementation to select the correct state (for Mamba) for each token is to directly get the correct index from a tensor. This is why inp_s_seq is storing int32_t and not floats. The other, less convenient way to select the correct state would be to have inp_KQ_mask contain 1.0f for each state used by a token and 0.0f otherwise. This complicates quickly fetching the first used state of a token, and is also less efficient because a whole row of the mask would always need to be read for each token. Using indexes makes it easy to stop searching when there are no more sequences for a token, and the first sequence assigned is always very quickly available (it's the first element of each row). * mamba : support llama_kv_cache_seq_cp copy chains * mamba : support shifting and dividing the kv cache pos * mamba : make the server and parallel examples work with whole sequences A seq_id is dedicated to the system prompt in both cases. * llama : make llama_kv_cache_seq_rm return whether it succeeded or not * mamba : dedicate an input tensor for state copy indices This is cleaner and makes it easier to adapt when/if token positions (and by extension, inp_K_shift) are no longer integers. 
* mamba : adapt perplexity, batched, and batched-bench examples * perplexity : limit the max number of sequences This adapts to what the loaded model can provide. * llama : add llama_n_max_seq to get the upper limit for seq_ids Used by the perplexity example. * batched : pass n_parallel to the model's context params This should have been there already, but it wasn't. * batched-bench : reserve sequences to support Mamba * batched-bench : fix tokens being put in wrong sequences Generation quality isn't what's measured in there anyway, but at least using the correct sequences avoids using non-consecutive token positions. * mamba : stop abusing attention metadata This breaks existing converted-to-GGUF Mamba models, but will allow supporting mixed architectures like MambaFormer without needing to break Mamba models. This will also allow changing the size of Mamba's states without having to reconvert models in the future. (e.g. using something else than d_conv - 1 columns for the conv_states will not require breaking existing converted Mamba models again) * gguf-py : add new KV metadata key-value pairs for Mamba * llama : add new metadata key-value pairs for Mamba * llama : guard against divisions by zero when n_head is 0 * mamba : rename "unlimited" KV cache property to "recurrent" * mamba : more correctly update the "used" field of the KV cache * ggml : in ggml_ssm_scan, use a threshold for soft_plus This is how the official Mamba implementation does it, and it's also what torch.nn.Softplus does. * convert : for Mamba, fallback to internal NeoX tokenizer The resulting models are exactly the same as if the tokenizer.json and tokenizer_config.json of GPT-NeoX were there. 
* mamba : support state saving and restoring * ggml : implicitly pass src tensors through dst for Mamba-related ops * mamba : clarify some comments * server : fix cache_tokens not getting correctly resized Otherwise, when the "we have to evaluate at least 1 token" special case was triggered, an extra token was kept in cache_tokens even if it was removed from the KV cache. For Mamba, this caused useless prompt reprocessing when the previous request triggered the above case. * convert-hf : support new metadata keys for Mamba For the models available at https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406 * mamba : rename metadata to be more similar to transformers library This breaks existing converted-to-GGUF models, but the metadata names are more "standard". * mamba : support mamba-*-hf models These models share their token_embd.weight with their output.weight * mamba : add missing spaces This is purely a formatting change. * convert-hf : omit output.weight when identical with token_embd.weight Only for Mamba for now, but it might be relevant for other models eventually. Most Mamba models actually share these two tensors, albeit implicitly. * readme : add Mamba to supported models, and add recent API changes * mamba : move state_seq and state_mask views outside layer loop A few tensors were also missing `struct` in front of `ggml_tensor`.
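For reference, the per-token recurrence that `ggml_ssm_scan` fuses can be sketched in plain Python for a single channel. This is a minimal sketch under stated assumptions: it follows the discretization of the reference Mamba selective scan (dt passed through a softplus with a threshold of 20, a decay of exp(dt * A), and an input term dt * B * x), it omits the D skip connection and the gating that surround the scan, and the function names and list-based layout are hypothetical rather than ggml's API.

```python
import math
from typing import List, Sequence


def softplus(x: float, threshold: float = 20.0) -> float:
    # Above the threshold, log(1 + exp(x)) is within float32 rounding of x,
    # so x is returned directly (the same shortcut torch.nn.Softplus takes).
    return x if x > threshold else math.log1p(math.exp(x))


def selective_scan_channel(
    xs: Sequence[float],
    dts: Sequence[float],
    A: Sequence[float],
    Bs: Sequence[Sequence[float]],
    Cs: Sequence[Sequence[float]],
    state: List[float],
) -> List[float]:
    """Hypothetical selective scan for ONE channel over a sequence of tokens.

    xs, dts: per-token input and raw (pre-softplus) step size.
    A: d_state decay parameters for this channel.
    Bs, Cs: per-token input/output projections of length d_state.
    state: the channel's ssm state (length d_state), updated in place.
    Returns one output y per token.
    """
    ys = []
    for x, dt_raw, B, C in zip(xs, dts, Bs, Cs):
        dt = softplus(dt_raw)
        y = 0.0
        for j in range(len(state)):
            # h_j <- exp(dt * A_j) * h_j + dt * B_j * x
            state[j] = state[j] * math.exp(dt * A[j]) + dt * B[j] * x
            # y = sum_j h_j * C_j
            y += state[j] * C[j]
        ys.append(y)
    return ys
```

Only the final `state` has to survive between calls, which is why the per-sequence cache of a recurrent model has a fixed size instead of growing with the number of tokens.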
2024-03-08 | llama : fix quantization of shared token_embd (#5944) | compilade
2024-03-08 | llama : assume tied weights if lm_head/output weights is missing (#5824) | Don Mahurin
This is to support model configurations with "tie_word_embeddings" set to true. Co-authored-by: Don Mahurin <2797413+dmahurin@users.noreply.github.com>
2024-03-07 | Revert "[SYCL] fix error when set main gpu to non-zero (#5901)" (#5918) | Neo Zhang Jianyu
This reverts commit ceca1aef0738b57951cd12c603c3477e75312dec.
2024-03-07 | server : refactor (#5882) | Georgi Gerganov
* server : refactoring (wip) * server : remove llava/clip objects from build * server : fix empty prompt handling + all slots idle logic * server : normalize id vars * server : code style * server : simplify model chat template validation * server : code style * server : minor * llama : llama_chat_apply_template support null buf * server : do not process embedding requests when disabled * server : reorganize structs and enums + naming fixes * server : merge oai.hpp in utils.hpp * server : refactor system prompt update at start * server : disable cached prompts with self-extend * server : do not process more than n_batch tokens per iter * server: tests: embeddings use a real embeddings model (#5908) * server, tests : bump batch to fit 1 embedding prompt * server: tests: embeddings fix build type Debug is randomly failing (#5911) * server: tests: embeddings, use different KV Cache size * server: tests: embeddings, fixed prompt do not exceed n_batch, increase embedding timeout, reduce number of concurrent embeddings * server: tests: embeddings, no need to wait for server idle as it can time out * server: refactor: clean up http code (#5912) * server : avoid n_available var ggml-ci * server: refactor: better http codes * server : simplify json parsing + add comment about t_last * server : rename server structs * server : allow to override FQDN in tests ggml-ci * server : add comments --------- Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com>
2024-03-07 | [SYCL] fix error when set main gpu to non-zero (#5901) | Neo Zhang Jianyu
* fix error when set main gpu to non-zero * fix delete condition
2024-03-05 | Vulkan Improvements (#5835) | 0cc4m
* Improve dequant shaders, add fast q4_0 dequant * Optimize dmmv non-kquants for GCN Remove unnecessary SPIR-V shader duplication * Fix q4_0 dequant dispatch sizes Fix backend free bug * Optimize dequant shaders for q4_1, q5_0, q5_1 and q8_0 * Add unary and binary op shader templates * Fix Vulkan check results * Enable non-contiguous support for simple ops * Add argsort Basic q4_0 mmq shader and unit test * Speed up q4_0 dequant code, enable mmq for q4_0 * Rework matmul pipeline selection * Add soft_max alibi support * Add q4_1, q5_0, q5_1 and q8_0 dequant mat mat mul shaders * Add environment variable GGML_VK_FORCE_MAX_ALLOCATION_SIZE to limit max buffer size Rename GGML_VULKAN_DISABLE_F16 to GGML_VK_DISABLE_F16 for consistency
2024-03-04 | llama : fix embeddings (#5796) | Georgi Gerganov
* llama : fix embeddings ggml-ci * llama : do not use KV cache for non-causal models ggml-ci * embeddings : fix llama_batch_init arg * llama : add pooling switch * llama : distinguish token vs sequence embeddings ggml-ci * llama : assert pooling tensor * llama : simplify causal mask condition ggml-ci * llama : assert input batch with pooling enabled * readme : update API changes list
2024-03-04 | add alias for chat template (#5858) | Xuan Son Nguyen
2024-03-03 | llama : allow for user specified embedding pooling type (#5849) | Douglas Hanley
* allow for user specified pooling type * llama : use enum types over int --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
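The pooling switch decides how per-token embeddings become one embedding per sequence. A minimal sketch, assuming three modes corresponding to no pooling, mean pooling, and keeping the first ([CLS]) token; the function and the string mode names are hypothetical and are not llama.cpp's enum or API.

```python
from typing import List, Sequence

Vector = List[float]


def pool_embeddings(token_embs: Sequence[Sequence[float]], mode: str) -> List[Vector]:
    """Hypothetical pooling of the token embeddings of ONE sequence.

    "none" returns the per-token embeddings unchanged,
    "mean" averages them into a single vector,
    "cls" keeps only the first token (the [CLS] position in BERT-style models).
    """
    if not token_embs:
        raise ValueError("sequence has no tokens")
    if mode == "none":
        return [list(e) for e in token_embs]
    if mode == "mean":
        n = len(token_embs)
        dim = len(token_embs[0])
        return [[sum(e[i] for e in token_embs) / n for i in range(dim)]]
    if mode == "cls":
        return [list(token_embs[0])]
    raise ValueError(f"unknown pooling mode: {mode}")
```

With no pooling the caller gets token embeddings; with the other two modes it gets a single sequence embedding, which mirrors the token vs. sequence distinction drawn in the embeddings fix above.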
2024-03-03 | llama : fix llama_copy_state_data with fragmented KV cache (#5840) | compilade
The row size of the saved states was based on kv_self.head while it should be based on llama_kv_cache_cell_max. Existing session files should still work. * llama : fix llama_kv_cache_cell_max inability to return 1 I've also changed its return type to uint32_t, because this function is always used to set the value of uint32_t variables, and because the index already has this type. * llama : fix state size calculation Some bytes in the state were unaccounted for in llama_get_state_size. Since the logits reserve so much space, it did not cause problems.
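The "inability to return 1" is a classic off-by-one in a backwards scan. A minimal sketch, assuming the function should return one past the index of the last occupied cell (so a saved row count covers every used cell); the Python names are hypothetical, and a list of booleans stands in for the KV cells.

```python
from typing import Sequence


def cell_max_buggy(occupied: Sequence[bool]) -> int:
    # Scans i = n-1 .. 1 and never looks at cell 0, so a cache whose
    # only used cell is cell 0 reports 0 instead of 1.
    for i in range(len(occupied) - 1, 0, -1):
        if occupied[i]:
            return i + 1
    return 0


def cell_max_fixed(occupied: Sequence[bool]) -> int:
    # Scans i = n .. 1 and tests cell i - 1, so cell 0 is included.
    for i in range(len(occupied), 0, -1):
        if occupied[i - 1]:
            return i
    return 0
```

Iterating over "one past" indices keeps the loop variable unsigned-friendly in C as well, since it never has to go below zero.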
2024-03-02 | llama : add abort_callback to interrupt computation (#5409) | Michael Podvitskiy
* using abort_callback from ggml to stop llama computation * format fix * a brief explaining comment --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-03-02 | llama : refactor internal quantization functions (#5830) | Xuan Son Nguyen
2024-03-02 | llama : fix segfault from unknown model arch name (#5820) | compilade
* llama : fix segfault from unknown model arch name * llama : make all LLM maps const This also requires using `std::map::at` instead of its `operator[]` which does not exist for const maps. * llama : name LLM_ARCH_UNKNOWN to "(unknown)" This avoids errors from `std::map::at` when getting the general name of the model architecture. Using "(unknown)" instead of an empty string as per suggestion https://github.com/ggerganov/llama.cpp/pull/5820#issuecomment-1973735284 * llama : remove redundant inner const for LLM_TENSOR_NAMES The extra const won't do anything here as const maps return const references to values. Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * llama : remove redundant nullptr check in llm_arch_from_string Since LLM_ARCH_NAMES is a const map, no spurious elements with a NULL name are inserted anymore, so this check is dead code. --------- Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
2024-03-02 | Support multiple GPUs (split mode) on SYCL backend (#5806) | Neo Zhang Jianyu
* support multiple cards: split-mode - layer|row * rm warning * rebase with master, support two new OPs, close feature for -sm=row, fix for unit test * update news * fix merge error * update according to review comments
2024-03-01 | llama : add StarCoder2 support (#5795) | Sourab Mangrulkar
* Add support for starcoder2 * handle rope type * skip rope freq and rotary embeddings from being serialized * resolve comments * Update llama.cpp * remove redundant changes * handle `rope-theta` * llama : change starcoder2 rope type * address comment --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-03-01 | llama : cleanup unused mmq flags (#5772) | Pierrick Hymbert
* cleanup unused --no-mul-mat-q,-nommq, -mmq, --mul-mat-q, mul_mat_q * remove: mul_mat_q in compare llama bench and usage * update llama-bench --------- Co-authored-by: slaren <slarengh@gmail.com>
2024-03-01 | unicode : switch to multimap based nfd_map (#5799) | Douglas Hanley
* switch to multimap based nfd_map due to compile time issues * simplify multimap keys * don't construct new locale every time
2024-02-29 | llama : constified `llama_set_state_data`'s `src` (#5774) | Marcus Dunn
2024-02-28 | llama : remove deprecated API (#5770) | Georgi Gerganov
ggml-ci
2024-02-28 | llama : fix non-quantization of expert gating tensors (#5754) | compilade
This reverts a single line from #5475
2024-02-28 | llama : improve BERT tokenization (#5740) | Douglas Hanley
* implement nfd for stripping accents in wpm tokenizer * sort nfd map; reuse iterator * use builtin tolower * add locale include * Simplify to_lower cases Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> --------- Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
2024-02-27 | IQ4_XS: a 4.25 bpw quantization (#5747) | Kawrakow
* Try IQ4_NL with blocks of 64 - does not look good * iq4_xs: go to super-blocks of 256 and 6-bit scales for blocks of 32 * iq4_xs: CUDA works - 133.2 t/s * iq4_xs: AVX2 dot product * iq4_xs: ARM_NEON dot product * iq4_nl: Metal implementation As usual, Metal / Apple Silicon don't like my quants. * iq3_xs: minor fix * iq4_xs: shrink by using IQ3_S for attn_k and attn_q * iq4_xs: revert using IQ3_S for attn_k and attn_v PPL vs size is good, but CPU performance suffers: on M2 Max TG-128 drops to 21.7 t/s from 28.8, and on a Ryzen-7950X to 14.5 t/s from 15.8 t/s. On CUDA we have 135 t/s when using IQ3_S vs 133 t/s with pure IQ4_XS. * Fix CI * iq4_xs: Added forgotten check for 256 divisibility --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
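The 4.25 bpw figure can be checked from the layout in the commit message: super-blocks of 256 weights with 4-bit quants and a 6-bit scale per block of 32. A small sketch, assuming each super-block also stores one 16-bit float scale and, for comparison, that IQ4_NL stores one 16-bit scale per block of 32 (both are assumptions, not stated above); the helper name is hypothetical.

```python
def bits_per_weight(n_weights: int, quant_bits: int, n_blocks: int,
                    block_scale_bits: int, extra_bits: int) -> float:
    """Average storage cost per weight for one (super-)block."""
    total = n_weights * quant_bits + n_blocks * block_scale_bits + extra_bits
    return total / n_weights


# IQ4_XS (assumed layout): 256 weights * 4 bits + 8 blocks * 6-bit scales
# + one 16-bit super-block scale = 1088 bits
iq4_xs = bits_per_weight(256, 4, 256 // 32, 6, 16)
# IQ4_NL (assumed layout): 32 weights * 4 bits + one 16-bit scale = 144 bits
iq4_nl = bits_per_weight(32, 4, 0, 0, 16)
```

Under these assumptions IQ4_XS spends 0.25 bits per weight on scales (48 + 16 bits over 256 weights) versus 0.5 for IQ4_NL (16 bits over 32 weights), which accounts for the whole 0.25 bpw difference between 4.5 and 4.25.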
2024-02-27 | llama : fix defrag bugs + add parameter (#5735) | Georgi Gerganov
* llama : fix defrag bugs + enable by default ggml-ci * llama : add defrag_thold parameter ggml-ci * llama : cont * llama : disable log message ggml-ci * llama : fix graph size check during defrag
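A threshold such as defrag_thold needs a fragmentation measure to compare against. A minimal sketch, assuming fragmentation is the fraction of empty cells below the highest occupied cell and that a negative threshold means "disabled"; both the metric and the function names are illustrative assumptions rather than what the commit implements.

```python
from typing import Sequence


def fragmentation(occupied: Sequence[bool]) -> float:
    """Hypothetical metric: fraction of empty cells below the highest used cell."""
    n = 0
    for i in range(len(occupied), 0, -1):
        if occupied[i - 1]:
            n = i
            break
    if n == 0:
        return 0.0
    used = sum(1 for c in occupied[:n] if c)
    return 1.0 - used / n


def should_defrag(occupied: Sequence[bool], defrag_thold: float) -> bool:
    # In this sketch a negative threshold disables defragmentation.
    return defrag_thold >= 0.0 and fragmentation(occupied) > defrag_thold
```

Measuring only up to the highest used cell means trailing free space does not count as fragmentation, since new tokens can simply be placed there.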
2024-02-26 | Adding IQ2_S and IQ2_M to complete coverage of the 2-3 bit quantization range (#5721) | Kawrakow
* Adding IQ2_S and IQ2_M as a single cumulative commit * Update examples/quantize/quantize.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-02-26 | [SYCL] Add support for soft_max ALiBi (#5639) | AidanBeltonS
* Add support for bias * Update pre-processor * rm commented code * fix format * fix CI --------- Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
2024-02-26 | llama : fix Gemma rope type (#5691) | Georgi Gerganov
2024-02-25 | llama : refactor k-shift implementation + KV defragmentation (#5691) | Georgi Gerganov
* llama : refactor k-shift implementation ggml-ci * llama : rename llama_kv_cache_seq_shift to llama_kv_cache_seq_add * llama : cont k-shift refactoring + normalize type names ggml-ci * minor : fix MPI builds * llama : reuse n_rot from the build context ggml-ci * llama : revert enum name changes from this PR ggml-ci * llama : update llama_rope_type * llama : add comment about rope values * llama : fix build * passkey : apply kv cache updates explicitly ggml-ci * llama : change name to llama_kv_cache_update() * llama : add llama_kv_cache_seq_pos_max() * passkey : fix llama_kv_cache_seq_pos_max() usage * llama : some llama_kv_cell simplifications * llama : add llama_kv_cache_compress (EXPERIMENTAL) * llama : add alternative KV cache merging (EXPERIMENTAL) * llama : add llama_kv_cache_defrag * llama : comments * llama : remove llama_kv_cache_compress will add in a separate PR ggml-ci * llama : defragment via non-overlapping moves * llama : ggml_graph based defrag implementation ggml-ci * llama : switch the loop order in build_defrag * llama : add comments
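"Defragment via non-overlapping moves" can be illustrated with a two-pointer compaction: holes near the front are filled with cells taken from the back, so every move reads an occupied cell and writes an empty one, and no move overwrites another move's source. A minimal sketch, assuming a move relocates one whole cell; the function name is hypothetical, and the actual implementation (which builds a ggml graph of K/V copies) may group moves differently.

```python
from typing import List, Sequence, Tuple


def plan_defrag(occupied: Sequence[bool]) -> Tuple[List[Tuple[int, int]], List[bool]]:
    """Return (moves, new_layout); each move is (src, dst) with src > dst."""
    cells = list(occupied)
    moves = []
    i, j = 0, len(cells) - 1
    while True:
        # i: first hole from the front
        while i < len(cells) and cells[i]:
            i += 1
        # j: last occupied cell from the back
        while j >= 0 and not cells[j]:
            j -= 1
        if i >= j:
            break  # everything before j is already packed
        moves.append((j, i))
        cells[i], cells[j] = True, False
    return moves, cells
```

Because every destination lies before the current back pointer and every source after the current front pointer, the moves can be executed in any order, which is what makes them safe to express as independent copies.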
2024-02-25 | code : normalize enum names (#5697) | Georgi Gerganov
* code : normalize enum names ggml-ci * code : cont * code : cont
2024-02-24 | IQ3_S: a much better alternative to Q3_K (#5676) | Kawrakow
* iq4_nl: squash commits for easier rebase * Basics (quantize, dequantize) * CUDA dequantize and dot product * Slightly faster CUDA dot product (120 t/s) * Switch to 6-bit scales * Scalar dot product * AVX2 dot product * ARM_NEON dot product * Works on metal, but still slow * Slightly better Metal dot product * Another small Metal improvement * Metal dot product is getting there * Faster CUDA dot product * Add 1/8 ffn_down layers as Q5_K when no imatrix has been provided * Report the actual bpw * Add _xs mix that is 4.05 bpw for non-MoE models * Remove IQ4_XS for now, slightly adjust kvalues_iq4nl * AVX2 dot product uses Q8_0 instead of Q8_K * Add to test-backend-ops * Minor fix * Also use Q5_K for attn_output in MoE models * Fixes after merging latest master * Switching to blocks of 32 * AVX2 for blocks of 32 * Scalar dot product for blocks of 32 * ARM_NEON dot product for blocks of 32 * Metal kernels for blocks of 32 * Slightly faster Metal kernels * Resurrecting iq3_xs After all the experimentation, nothing was better than this.
* Minor PPL improvement via a block scale fudge factor * Minor improvement via 3 neighbours * iq3_xs: working scalar and AVX2 dot products * iq3_xs: ARM_NEON dot product - works but extremely slow (10 t/s) * iq3_xs: working Metal implementation * Adding IQ3_M - IQ3_XS mix with mostly Q4_K * iq3_xs: a 3.4375 bpw variant * iq3_xs: make CUDA work for new version * iq3_xs: make scalar and AVX2 work for new version * iq3_s: make ARM_NEON work with new version * iq3_xs: make new version work on metal Performance is very similar to Q3_K_S * iq3_xs: tiny Metal speed improvement * iq3_xs: tiny Metal speed improvement * Fix stupid warning * Q3_K_XS now uses a mix of IQ3_XS and IQ3_XXS * iq3_xs: rename to iq3_s * iq3_s: make tests pass * Move Q3_K_XS mix to 3.25 bpw * Attempt to fix failing tests * Another attempt to fix the Windows builds * Attempt to fix ROCm * ROCm again * iq3_s: partial fix for QK_K = 64 * iq3_s: make it work on metal for QK_K = 64 Pleasant surprise: the coding was super-block size independent, so all it took was to delete some QK_K == 256 guards. * Will this fix ROCm? --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2024-02-22 | mpt : do not duplicate token_embd.weight on disk (#5670) | Jared Van Bortel
2024-02-22 | gemma : use more bits for the token_embd.weight tensor (#5650) | Georgi Gerganov
* gemma : use Q8_0 for the token_embd.weight tensor * llama : quantize token_embd.weight using output type
2024-02-22 | py : add Gemma conversion from HF models (#5647) | Georgi Gerganov
* py : add gemma conversion from HF models * Update convert-hf-to-gguf.py Co-authored-by: Aarni Koskela <akx@iki.fi> * Update convert-hf-to-gguf.py Co-authored-by: Aarni Koskela <akx@iki.fi> * Update convert-hf-to-gguf.py Co-authored-by: Jared Van Bortel <jared@nomic.ai> --------- Co-authored-by: Aarni Koskela <akx@iki.fi> Co-authored-by: Jared Van Bortel <jared@nomic.ai>
2024-02-22 | Add Gemma chat template (#5665) | Xuan Son Nguyen
* add gemma chat template * gemma: only apply system_prompt on non-model message
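A sketch of the Gemma turn layout, assuming the `<start_of_turn>`/`<end_of_turn>` markers, the `model` role name for assistant turns, and the behaviour described above of folding a system message into the next non-model message. The function name, the message tuples, and the blank-line separator after the system text are hypothetical choices for illustration, not llama.cpp's implementation.

```python
from typing import List, Tuple

Message = Tuple[str, str]  # (role, content)


def format_gemma(messages: List[Message], add_assistant: bool = True) -> str:
    out = []
    pending_system = ""
    for role, content in messages:
        if role == "system":
            # Gemma has no system turn: hold the text for the next non-model message.
            pending_system = content.strip()
            continue
        role = "model" if role == "assistant" else role
        out.append(f"<start_of_turn>{role}\n")
        if pending_system and role != "model":
            out.append(pending_system + "\n\n")  # separator is an assumption
            pending_system = ""
        out.append(content.strip() + "<end_of_turn>\n")
    if add_assistant:
        # Open a model turn for the reply.
        out.append("<start_of_turn>model\n")
    return "".join(out)
```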
2024-02-22 | minor : fix trailing whitespace (#5638) | Georgi Gerganov
2024-02-22 | server : fallback to chatml, add AlphaMonarch chat template (#5628) | Xuan Son Nguyen
* server: fallback to chatml * add new chat template * server: add AlphaMonarch to test chat template * server: only check model template if there is no custom tmpl * remove TODO
2024-02-22 | mpt : add optional bias tensors (#5638) | Dat Quoc Nguyen
Update for MPT with optional bias parameters: to work with PhoGPT and SEA-LION models that were pre-trained with 'bias'.
2024-02-22 | llama : fix loading models with shared tok_embd and output (#5651) | slaren
ggml-ci
2024-02-21 | llama : fix session save/load with quantized KV (#5649) | slaren
2024-02-21 | gemma : allow offloading the output tensor (#5646) | slaren
2024-02-21 | llama : add `gemma` model (#5631) | postmasters
There are a couple of things in this architecture: 1. Shared input and output embedding parameters. 2. Key length and value length are not derived from `n_embd`. More information about the models can be found at https://ai.google.dev/gemma. GGUFs can be downloaded from https://huggingface.co/google.
2024-02-21 | IQ4_NL: 4-bit non-linear quants with blocks of 32 (#5590) | Kawrakow
* iq4_nl: squash commits for easier rebase * Basics (quantize, dequantize) * CUDA dequantize and dot product * Slightly faster CUDA dot product (120 t/s) * Switch to 6-bit scales * Scalar dot product * AVX2 dot product * ARM_NEON dot product * Works on metal, but still slow * Slightly better Metal dot product * Another small Metal improvement * Metal dot product is getting there * Faster CUDA dot product * Add 1/8 ffn_down layers as Q5_K when no imatrix has been provided * Report the actual bpw * Add _xs mix that is 4.05 bpw for non-MoE models * Remove IQ4_XS for now, slightly adjust kvalues_iq4nl * AVX2 dot product uses Q8_0 instead of Q8_K * Add to test-backend-ops * Minor fix * Also use Q5_K for attn_output in MoE models * Fixes after merging latest master * Switching to blocks of 32 * AVX2 for blocks of 32 * Scalar dot product for blocks of 32 * ARM_NEON dot product for blocks of 32 * Metal kernels for blocks of 32 * Slightly faster Metal kernels * iq4_nl: Fix after merging with master * iq4_nl: another fix after merging with master * Use IQ4_NL instead of Q4_K when using k-quants is not possible * Fix typo that makes several tests fail * It was the ggml_vdotq thing missed inside the brackets --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2024-02-20 | Server: use llama_chat_apply_template (#5593) | Xuan Son Nguyen
* server: use llama_chat_apply_template * server: remove trailing space * server: fix format_chat * server: fix help message Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * server: fix formatted_chat --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-02-19 | minor : fix trailing whitespace (#5538) | Georgi Gerganov
2024-02-19 | llama : add llama_chat_apply_template() (#5538) | Xuan Son Nguyen
* llama: add llama_chat_apply_template * test-chat-template: remove redundant vector * chat_template: do not use std::string for buffer * add clarification for llama_chat_apply_template * llama_chat_apply_template: add zephyr template * llama_chat_apply_template: correct docs * llama_chat_apply_template: use term "chat" everywhere * llama_chat_apply_template: change variable name to "tmpl"
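For a concrete picture of what applying a chat template produces, here is a minimal sketch of ChatML rendering (the format the server falls back to). The function and the tuple-based messages are hypothetical; only the `<|im_start|>`/`<|im_end|>` layout is meant to reflect ChatML.

```python
from typing import List, Tuple

Message = Tuple[str, str]  # (role, content)


def format_chatml(messages: List[Message], add_assistant: bool = True) -> str:
    """Render messages in ChatML, optionally opening an assistant turn."""
    parts = [f"<|im_start|>{role}\n{content}<|im_end|>\n" for role, content in messages]
    if add_assistant:
        # Leave an open assistant turn for the model to complete.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)
```

As we understand the C API, `llama_chat_apply_template` writes into a caller-provided buffer and returns the total length required, so a caller whose buffer was too small can resize it and call again.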
2024-02-18 | 1.5 bit quantization (#5453) | Kawrakow
* iq1_s: WIP basics * iq1_s: CUDA is working * iq1_s: scalar CPU dot product * iq1_s: WIP AVX2 dot product - something is not right * Fix tests * Fix shadow warnings * Fix after merge with latest master * iq1_s: AVX2 finally works * iq1_s: ARM_NEON dot product. Works, but not very fast * iq1_s: better grid * iq1_s: use IQ2_XXS for attn_output At a cost of 0.04 extra bpw this gives a big improvement in PPL. * iq1_s: Metal basics Dequantize works, but not dot product * iq1_s: Metal works, but quite slow As usual, Apple Silicon does not like the code I write. * iq1_s: Tests * iq1_s: slightly faster dot product --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2024-02-17 | ggml : add ALiBi support for ggml_soft_max_ext (#5488) | Georgi Gerganov
* ggml : avoid recomputing alibi slopes (CPU) * llama : reuse hparams.f_max_alibi_bias in all cases ggml-ci * ggml : support alibi bias in ggml_soft_max_ext (CPU + Metal) ggml-ci * ggml : handle all SRCs (do not break on first null) ggml-ci * tests : do not use slope for large soft_max accumulates too much error ggml-ci * ggml : alternative ALiBi without extra tensor We compute the slopes in the kernel ggml-ci * cuda : add ALiBi support in ggml_soft_max_ext ggml-ci * ggml : deprecate ggml_alibi * ggml : support multi-sequence ALiBi (Metal) ggml-ci * cuda : add multi-seq ALiBi + remove F16 soft_max ggml-ci * ggml : update deprecation message * ggml : fix pos ptr when no ALiBi ggml-ci * cuda : fix performance (pow -> powf) * cuda : precompute ALiBi constants * metal : pre-compute ALiBi slopes ggml-ci * llama : init kq_pos only if needed ggml-ci * test-backend-ops : add null pos test to soft_max test-backend-ops : replace soft_max tests ggml-ci --------- Co-authored-by: slaren <slarengh@gmail.com>
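"We compute the slopes in the kernel" refers to the per-head ALiBi slopes. A sketch of the usual slope schedule, assuming geometric slopes 2^(-max_bias*(h+1)/n) for the first n heads (n being the largest power of two not exceeding n_head) and odd powers of a second base for the remaining heads; it is written from the ALiBi scheme, not copied from the ggml kernels, and the function name is hypothetical.

```python
import math
from typing import List


def alibi_slopes(n_head: int, max_bias: float = 8.0) -> List[float]:
    # Largest power of two not exceeding n_head.
    n_head_log2 = 2 ** int(math.floor(math.log2(n_head)))
    m0 = 2.0 ** (-max_bias / n_head_log2)
    m1 = 2.0 ** (-(max_bias / 2.0) / n_head_log2)
    slopes = []
    for h in range(n_head):
        if h < n_head_log2:
            slopes.append(m0 ** (h + 1))
        else:
            # Heads beyond the power of two take odd powers of the second base,
            # interleaving between the slopes already used.
            slopes.append(m1 ** (2 * (h - n_head_log2) + 1))
    return slopes
```

Each slope is a function of the head index alone, which is why it can be computed inside the kernel instead of being passed in as an extra tensor. If the bias is formed from key positions alone (slope * key_pos), it differs from the textbook penalty -slope * (query_pos - key_pos) only by a per-row constant, which softmax ignores.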
2024-02-16 | llama : minor fixed return int value (#5529) | Herman Semenov
2024-02-16 | ggml : add numa options (#5377) | bmwl
* Added numa options to allow finer grained control as well as plumbing for a new mirror mode that will require numa.h * Reverted Makefile * Fixed include * Removed sched.h from ggml.h, moved ggml_get_numa_affinity into ggml.c, removed trailing whitespace and fixed up a few inconsistent variables * removed trailing whitespace * Added numa options to allow finer grained control as well as plumbing for a new mirror mode that will require numa.h * Reverting Makefile * Fixed a number of issues with the move from BOOL to ggml_numa_strategies. Added a note about mirror mode not being implemented yet * Removing MIRROR_MODE code for this PR * Removing last bit of MIRROR_MODE code for this PR * Removing unneeded branch in server.cpp example and moving get_numa_affinity and making it static * Fixed lingering init_llama_backend() bool calls in tests and examples * Remove enum llama_numa_strategies * Revert bad merge with dynatemp flags * add missing enum ggml_numa_strategies declaration and revert sync problem with master * add missing enum ggml_numa_strategies declaration * fixed ggml_init_numa variable * Update ggml.h Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * Update READMEs with info about numa flags, change INTERLEAVE strategy name to DISTRIBUTE everywhere, implement the improved distribution strategy from @rankaiyx, fix a spelling mistake and un-merge some bad merges * split numa init out from llama_backend_init and created llama_numa_init.
Updated all code paths and samples * Fix up some boolean vs enum comparisons * Added #ifdefs for non-Linux OS that don't have cpu_set_t datatype * Update ggml.h Align enum values Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml.c Remove whitespace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml.c align parameters Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update examples/server/server.cpp remove whitespace and align brace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update common/common.cpp Remove whitespace and align brace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * unified ggml_numa_strategy enum and fixed text alignment in server.cpp example * Update ggml.c simplified return for platforms without NUMA support Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * removed redundant else from cli argument processing of --numa * whitespace --------- Co-authored-by: root <root@nenya.lothlorien.ca> Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Jared Van Bortel <jared@nomic.ai>
2024-02-15 | Use correct type of pooling for embedding models (#5500) | Douglas Hanley