author     compilade <113953597+compilade@users.noreply.github.com>  2024-03-08 17:31:00 -0500
committer  GitHub <noreply@github.com>  2024-03-08 17:31:00 -0500
commit     c2101a2e909ac7c08976d414e64e96c90ee5fa9e (patch)
tree       0cbd7770adf3c51f7f3a252916abdbf8876daf75 /llama.cpp
parent     515f7d0d4fce41c752fc253acf30707c3be2531e (diff)
llama : support Mamba Selective State Space Models (#5328)
* mamba : begin working on support for Mamba SSM
* mamba : begin figuring out how to (ab)use the kv cache for Mamba
* mamba : recurrent inference almost works, but incoherent
* mamba : recurrent inference WORKS!!!
* convert : optionally use d_conv and d_state from config.json for Mamba
* mamba : refactor recurrent conv, resulting in 20% perf increase
  It's still slower than I'd like, but I did not really optimize `ggml_exp` yet.
  I also refactored `ggml_exp` to work with tensors with more than 2 dimensions.
* ggml : parallelize ggml_exp
  This results in 8% faster token generation for Mamba-130M.
* mamba : simplify the conv step with a self-overlapping view
  Turns out the conv_state can be made smaller by one column. Note that this breaks existing GGUFs of Mamba, because the key_value_length field is tied to the conv_state size.
  Convolution with a self-overlapping view is cool! And it's much simpler than what I initially thought would be necessary to make the convolution step work with more than 1 token at a time.
  Next step is to make the SSM step work on batches of tokens too, and thus I need to figure out a way to make a parallel selective scan which will keep the ssm_state small and won't make it bigger by a factor of (n_layer * batch_size).
* llama : fix Mamba KV self size wrongly displaying as f16 instead of f32
  Relatedly, I also tried to see if other types than f32 worked for the states, but they don't, because of the operators used. It's probably better anyway to keep lots of precision there, since the states are small anyway.
* mamba : fix self-overlapping view depth stride
* mamba : handle batches of more than 1 token
  This means running Mamba no longer crashes when using the default settings! And probably also slightly faster prompt processing. Both batched and non-batched processing yield the same output.
  Previously, the state was not cleared when starting a sequence. Next step is to make the KV cache API work as expected for Mamba models.
* ggml : add ggml_ssm_scan to help with parallel selective scan
  If the selective scan was implemented without a custom operator, there would be waaay too many nodes in the graph. For example, for Mamba-130M, with a batch size of 512 (the default), a naive selective scan could add at least 24*512=12288 nodes, which is more than LLAMA_MAX_NODES (8192), and that's only for the smallest Mamba model. So it's much cleaner with a custom operator. Not sure about the name, though.
* ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation
  This will help with performance on CPU if ggml_vec_mul_f32 and ggml_vec_add_f32 are ever optimized with SIMD.
* mamba : very basic quantization support
  Mostly works, but there is currently no difference between the variants of a k-quant (e.g. Q4_K_S and Q4_K_M are the same). Most of the SSM-specific weights can be kept in f32 without affecting the size that much, since they are relatively small. (the linear projection weights are responsible for most of Mamba's size)
  Too much quantization seems to make the state degrade quite fast, and the model begins to output gibberish. It seems to affect bigger models to a lesser extent than small models, but I'm not sure by how much. Experimentation will be needed to figure out which weights are more important for the _M (and _L?) variants of k-quants for Mamba.
* convert : fix wrong name for layer norm weight of official Mamba models
  I was using Q-bert/Mamba-* models before, which have a slightly different naming scheme for the weights. (they start with "model.layers" instead of "backbone.layers")
* mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator
  This increases performance on CPU by around 30% for prompt processing, and by around 20% for text generation. However, it also makes the ggml_exp and ggml_soft_plus operators unused. Whether or not they should be kept will be decided later.
* convert : for Mamba, also consider the "MambaLMHeadModel" arch name
  It's the name of the class of the official implementation, though they don't use it (yet) in the "architectures" field of config.json
* mamba : fix vocab size problems with official models
  The perplexity was waaaay too high for models with a non-round vocab size. Not sure why, but it needed to be fixed in the metadata. Note that this breaks existing GGUF-converted Mamba models, but **only if** the vocab size was not already rounded.
* ggml : remove ggml_exp and ggml_soft_plus
  They did not exist anyway outside of this branch, and since ggml_ssm_scan fused operations together, they are unused. It's always possible to bring them back if needed.
* mamba : remove some useless comments
  No code change.
* convert : fix flake8 linter errors
* mamba : apply suggestions from code review
* mamba : remove unnecessary branch for row-wise ssm_state and C multiplication
  It was previously done to avoid permuting when only one token is processed at a time (like when generating text), but permuting is cheap, and dynamically changing the compute graph is not future-proof.
* ggml : in ggml_ssm_scan, use more appropriate asserts
* ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32
* mamba : multiple sequences, but one at a time
  This is a step towards making this Mamba implementation usable with the server example (the way the system prompt is kept when clearing the client slots will need to be changed before this can work, though).
  The KV cache size for this kind of model is tied to the maximum number of sequences kept at any single time. For now, this number is obtained from n_parallel (plus one, to have an extra sequence to dedicate to the system prompt), but there might be a better way to do this which won't also make the main example use 2 cells even if only 1 is really used. (for this specific case, --parallel 0 helps)
  Simultaneous sequence processing will probably require changes to ggml_ssm_scan, and possibly a new operator for the conv step.
* mamba : support llama_kv_cache_seq_cp
  This (mis)uses the logic around K shifts, because tokens in a state can't be shifted anyway, and because inp_K_shift has the right shape and type. Using ggml_get_rows is a nice way to do copies, but copy chains can't work. Fortunately, copy chains don't really seem to be used in the examples. Each KV cell is dedicated to the sequence ID corresponding to its own index.
* mamba : use a state mask
  It's cleaner than the previous heuristic of checking for the pos of the first token in the batch. inp_KQ_mask could not be re-used for this, because it has the wrong shape and because it seems more suited to the next step of simultaneous sequence processing (helping with the problem of remembering which token belongs to which sequence(s)/state(s)).
* llama : replace the usage of n_ctx with kv_self.size in many places
* mamba : use n_tokens directly instead of n_tok
* mamba : in comments, properly refer to KV cells instead of slots
* mamba : reduce memory usage of ggml_ssm_scan
  From 290.37 MiB to 140.68 MiB of CPU compute buffer size with Mamba 3B with a batch size of 512.
  The result tensor of ggml_ssm_scan was previously a big part of the CPU compute buffer size. To make it smaller, it does not contain the intermediate ssm states anymore. Both y and the last ssm state are combined in the result tensor, because it seems only a single tensor can be returned by an operator with the way the graph is built.
* mamba : simultaneous sequence processing
  A batch can now contain tokens from multiple sequences. This is necessary for at least the parallel example, the server example, and the HellaSwag test in the perplexity example. However, for this to be useful, uses of llama_kv_cache_seq_rm/cp will need to be changed to work on whole sequences.
* ggml : add ggml_ssm_conv as a new operator for the conv step of Mamba
  This operator makes it possible to use and update the correct states for each token of the batch in the same way as ggml_ssm_scan. Other solutions which use existing operators would need loops which would add too many nodes to the graph (at least the ones I thought of).
  Using this operator further reduces the size of the CPU compute buffer from 140.68 MiB to 103.20 MiB with Mamba 3B with a batch size of 512. And (at least on CPU), it's a bit faster than before.
  Note that "ggml_ssm_conv" is probably not the most appropriate name, and it could be changed if a better one is found.
* llama : add inp_s_seq as a new input tensor
  The most convenient implementation to select the correct state (for Mamba) for each token is to directly get the correct index from a tensor. This is why inp_s_seq is storing int32_t and not floats.
  The other, less convenient way to select the correct state would be to have inp_KQ_mask contain 1.0f for each state used by a token and 0.0f otherwise. This complicates quickly fetching the first used state of a token, and is also less efficient because a whole row of the mask would always need to be read for each token.
  Using indexes makes it easy to stop searching when there are no more sequences for a token, and the first sequence assigned is always very quickly available (it's the first element of each row).
* mamba : support llama_kv_cache_seq_cp copy chains
* mamba : support shifting and dividing the kv cache pos
* mamba : make the server and parallel examples work with whole sequences
  A seq_id is dedicated to the system prompt in both cases.
* llama : make llama_kv_cache_seq_rm return whether it succeeded or not
* mamba : dedicate an input tensor for state copy indices
  This is cleaner and makes it easier to adapt when/if token positions (and by extension, inp_K_shift) are no longer integers.
* mamba : adapt perplexity, batched, and batched-bench examples
* perplexity : limit the max number of sequences
  This adapts to what the loaded model can provide.
* llama : add llama_n_max_seq to get the upper limit for seq_ids
  Used by the perplexity example.
* batched : pass n_parallel to the model's context params
  This should have been there already, but it wasn't.
* batched-bench : reserve sequences to support Mamba
* batched-bench : fix tokens being put in wrong sequences
  Generation quality isn't what's measured in there anyway, but at least using the correct sequences avoids using non-consecutive token positions.
* mamba : stop abusing attention metadata
  This breaks existing converted-to-GGUF Mamba models, but will allow supporting mixed architectures like MambaFormer without needing to break Mamba models. This will also allow changing the size of Mamba's states without having to reconvert models in the future. (e.g. using something else than d_conv - 1 columns for the conv_states will not require breaking existing converted Mamba models again)
* gguf-py : add new KV metadata key-value pairs for Mamba
* llama : add new metadata key-value pairs for Mamba
* llama : guard against divisions by zero when n_head is 0
* mamba : rename "unlimited" KV cache property to "recurrent"
* mamba : more correctly update the "used" field of the KV cache
* ggml : in ggml_ssm_scan, use a threshold for soft_plus
  This is how the official Mamba implementation does it, and it's also what torch.nn.Softplus does.
* convert : for Mamba, fallback to internal NeoX tokenizer
  The resulting models are exactly the same as if the tokenizer.json and tokenizer_config.json of GPT-NeoX were there.
* mamba : support state saving and restoring
* ggml : implicitly pass src tensors through dst for Mamba-related ops
* mamba : clarify some comments
* server : fix cache_tokens not getting correctly resized
  Otherwise, when the "we have to evaluate at least 1 token" special case was triggered, an extra token was kept in cache_tokens even if it was removed from the KV cache. For Mamba, this caused useless prompt reprocessing when the previous request triggered the above case.
* convert-hf : support new metadata keys for Mamba
  For the models available at https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406
* mamba : rename metadata to be more similar to transformers library
  This breaks existing converted-to-GGUF models, but the metadata names are more "standard".
* mamba : support mamba-*-hf models
  These models share their token_embd.weight with their output.weight
* mamba : add missing spaces
  This is purely a formatting change.
* convert-hf : omit output.weight when identical with token_embd.weight
  Only for Mamba for now, but it might be relevant for other models eventually. Most Mamba models actually share these two tensors, albeit implicitly.
* readme : add Mamba to supported models, and add recent API changes
* mamba : move state_seq and state_mask views outside layer loop
  A few tensors were also missing `struct` in front of `ggml_tensor`.
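For reference, the recurrence that ggml_ssm_scan fuses (and that would otherwise blow past LLAMA_MAX_NODES, as noted above) can be sketched for a single sequence as follows. The layouts, names, and softplus cutoff here are illustrative assumptions, not the actual ggml kernel:

#include <math.h>

// sketch of the fused selective scan for ONE sequence; row-major layouts assumed
static void ssm_scan_ref(
        float       * s,   // [d_inner][d_state] recurrent state, updated in place
        float       * y,   // [n_tokens][d_inner] output (before the D skip and SiLU(z) gating)
        const float * x,   // [n_tokens][d_inner] input, after the conv step
        const float * dt,  // [n_tokens][d_inner] time step, before softplus
        const float * A,   // [d_inner][d_state]
        const float * B,   // [n_tokens][d_state]
        const float * C,   // [n_tokens][d_state]
        int n_tokens, int d_inner, int d_state) {
    for (int t = 0; t < n_tokens; ++t) {
        for (int i = 0; i < d_inner; ++i) {
            // softplus with a cutoff for large inputs, as mentioned in the commit message
            const float dt_ti  = dt[t*d_inner + i];
            const float delta  = dt_ti <= 20.0f ? log1pf(expf(dt_ti)) : dt_ti;
            const float x_dt   = x[t*d_inner + i] * delta;
            float       sum_cy = 0.0f;
            for (int j = 0; j < d_state; ++j) {
                // h = exp(delta*A)*h + delta*B*x  (discretized state update), then y = C*h
                float * sij = &s[i*d_state + j];
                *sij    = *sij * expf(delta * A[i*d_state + j]) + B[t*d_state + j] * x_dt;
                sum_cy += *sij * C[t*d_state + j];
            }
            y[t*d_inner + i] = sum_cy;
        }
    }
}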
Diffstat (limited to 'llama.cpp')
-rw-r--r--  llama.cpp  698
1 file changed, 660 insertions, 38 deletions
diff --git a/llama.cpp b/llama.cpp
index 4a20b792..8c147a42 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -213,6 +213,7 @@ enum llm_arch {
LLM_ARCH_MINICPM,
LLM_ARCH_GEMMA,
LLM_ARCH_STARCODER2,
+ LLM_ARCH_MAMBA,
LLM_ARCH_UNKNOWN,
};
@@ -241,6 +242,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM, "minicpm" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
+ { LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -284,6 +286,11 @@ enum llm_kv {
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
LLM_KV_ROPE_SCALING_FINETUNED,
+ LLM_KV_SSM_INNER_SIZE,
+ LLM_KV_SSM_CONV_KERNEL,
+ LLM_KV_SSM_STATE_SIZE,
+ LLM_KV_SSM_TIME_STEP_RANK,
+
LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_LIST,
LLM_KV_TOKENIZER_TOKEN_TYPE,
@@ -342,6 +349,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
+ { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" },
+ { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
+ { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
+ { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
+
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
{ LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
@@ -399,6 +411,13 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM,
+ LLM_TENSOR_SSM_IN,
+ LLM_TENSOR_SSM_CONV1D,
+ LLM_TENSOR_SSM_X,
+ LLM_TENSOR_SSM_DT,
+ LLM_TENSOR_SSM_A,
+ LLM_TENSOR_SSM_D,
+ LLM_TENSOR_SSM_OUT,
};
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -802,6 +821,22 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
},
},
{
+ LLM_ARCH_MAMBA,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+ { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+ },
+ },
+ {
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -1613,6 +1648,12 @@ struct llama_hparams {
float rope_freq_scale_train;
uint32_t n_yarn_orig_ctx;
+ // for State Space Models
+ uint32_t ssm_d_conv = 0;
+ uint32_t ssm_d_inner = 0;
+ uint32_t ssm_d_state = 0;
+ uint32_t ssm_dt_rank = 0;
+
float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
@@ -1641,6 +1682,11 @@ struct llama_hparams {
if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
+ if (this->ssm_d_conv != other.ssm_d_conv) return true;
+ if (this->ssm_d_inner != other.ssm_d_inner) return true;
+ if (this->ssm_d_state != other.ssm_d_state) return true;
+ if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
+
const float EPSILON = 1e-9f;
if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
@@ -1652,6 +1698,9 @@ struct llama_hparams {
}
uint32_t n_gqa() const {
+ if (n_head_kv == 0) {
+ return 0;
+ }
return n_head/n_head_kv;
}
@@ -1662,6 +1711,18 @@ struct llama_hparams {
uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads
return n_embd_head_v * n_head_kv;
}
+
+ uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
+ // corresponds to Mamba's conv_states size
+ // TODO: maybe support other convolution strides than 1
+ // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+ return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+ }
+
+ uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
+ // corresponds to Mamba's ssm_states size
+ return ssm_d_state * ssm_d_inner;
+ }
};
struct llama_cparams {
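To make the two per-cell sizes introduced by n_embd_k_s() and n_embd_v_s() concrete, here is a rough sketch; the hyperparameters are assumptions in the style of a ~2.8B Mamba model, not values read from any particular GGUF:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    // assumed hyperparameters: d_conv = 4, d_inner = 5120, d_state = 16, 64 layers
    const uint32_t ssm_d_conv = 4, ssm_d_inner = 5120, ssm_d_state = 16, n_layer = 64;
    const uint32_t n_embd_k_s = (ssm_d_conv - 1) * ssm_d_inner; // conv_state row per layer
    const uint32_t n_embd_v_s =  ssm_d_state     * ssm_d_inner; // ssm_state row per layer
    const uint64_t bytes = (uint64_t)(n_embd_k_s + n_embd_v_s) * sizeof(float) * n_layer;
    // roughly 24 MiB per KV cell, i.e. per sequence, independent of context length
    printf("state per sequence: %.1f MiB\n", bytes / (1024.0 * 1024.0));
    return 0;
}

Unlike a Transformer KV cache, this footprint does not grow with the number of past tokens, only with the number of sequences kept around.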
@@ -1739,11 +1800,27 @@ struct llama_layer {
struct ggml_tensor * ffn_down_b; // b2
struct ggml_tensor * ffn_up_b; // b3
struct ggml_tensor * ffn_act;
+
+ // mamba proj
+ struct ggml_tensor * ssm_in;
+ struct ggml_tensor * ssm_x;
+ struct ggml_tensor * ssm_dt;
+ struct ggml_tensor * ssm_out;
+
+ // mamba
+ struct ggml_tensor * ssm_conv1d;
+ struct ggml_tensor * ssm_a;
+ struct ggml_tensor * ssm_d;
+
+ // mamba bias
+ struct ggml_tensor * ssm_conv1d_b;
+ struct ggml_tensor * ssm_dt_b;
};
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
+ int32_t src = 0; // used by recurrent state models to copy states
std::set<llama_seq_id> seq_id;
@@ -1764,6 +1841,9 @@ struct llama_kv_cell {
struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
+ bool do_copy = false;
+ // with recurrent state models, a cell can hold the state for more than one past token
+ bool recurrent = false;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it
@@ -2003,11 +2083,14 @@ struct llama_context {
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
- struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
- struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx]
- struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
+ struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
+ struct ggml_tensor * inp_KQ_pos; // F32 [kv_size]
+ struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
+ struct ggml_tensor * inp_s_copy; // I32 [kv_size]
+ struct ggml_tensor * inp_s_mask; // F32 [kv_size]
+ struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
#ifdef GGML_USE_MPI
ggml_mpi_context * ctx_mpi = NULL;
@@ -2023,25 +2106,42 @@ static bool llama_kv_cache_init(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
- uint32_t n_ctx,
+ uint32_t kv_size,
bool offload) {
const struct llama_hparams & hparams = model.hparams;
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
const int64_t n_layer = hparams.n_layer;
cache.has_shift = false;
+ // TODO: find a nicer way to add other recurrent model architectures
+ cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+
+ // TODO: support mixed recurrent Transformer architectures
+ // NOTE: (!a || b) is a logical implication (a -> b)
+ GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s());
+ GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s());
+ GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa());
+ GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa());
+
cache.head = 0;
- cache.size = n_ctx;
+ cache.size = kv_size;
cache.used = 0;
cache.type_k = type_k;
cache.type_v = type_v;
cache.cells.clear();
- cache.cells.resize(n_ctx);
+ cache.cells.resize(kv_size);
+
+ if (cache.recurrent) {
+ // init state copy sources
+ for (uint32_t i = 0; i < cache.size; ++i) {
+ cache.cells[i].src = i;
+ }
+ }
#ifdef GGML_USE_CLBLAST
offload = false;
@@ -2080,8 +2180,8 @@ static bool llama_kv_cache_init(
for (int i = 0; i < (int) n_layer; i++) {
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*n_ctx);
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*n_ctx);
+ ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+ ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
@@ -2115,6 +2215,54 @@ static bool llama_kv_cache_find_slot(
const uint32_t n_ctx = cache.size;
const uint32_t n_tokens = batch.n_tokens;
+ if (cache.recurrent) {
+ // For recurrent state architectures (like Mamba),
+ // each KV cache cell can store the state for a whole sequence.
+
+ llama_seq_id min = cache.size - 1;
+ llama_seq_id max = 0;
+
+ for (uint32_t i = 0; i < n_tokens; ++i) {
+ for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
+ llama_seq_id seq_id = batch.seq_id[i][j];
+ // make sure it's a valid seq_id
+ if ((uint32_t) seq_id < cache.size) {
+ if (seq_id > max) {
+ max = seq_id;
+ }
+ if (seq_id < min) {
+ min = seq_id;
+ }
+ // Assuming the tokens are in-order
+ if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
+ // What should happen when the pos backtracks or skips a value?
+ // Clearing the state mid-batch would require special-casing which isn't done.
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
+ __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
+ }
+ if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
+ cache.used += 1;
+ }
+ cache.cells[seq_id].pos = batch.pos[i];
+ // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
+ } else {
+ // too big seq_id
+ // TODO: would it be possible to resize the KV cache size instead?
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
+ return false;
+ }
+ }
+ }
+
+ // allow getting the range of used cells, from head to head + n
+ cache.head = min;
+ cache.n = max - min + 1;
+
+ // sanity check
+ return max >= min;
+ }
+ // otherwise, one cell per token.
+
if (n_tokens > n_ctx) {
LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
return false;
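For illustration (the numbers are made up): if a batch holds tokens for sequences 2, 3 and 5, the recurrent branch above gives min = 2 and max = 5, so cache.head becomes 2 and cache.n becomes 4, and the build functions then operate on the contiguous cell range [2, 6) of each per-layer state tensor, even though cell 4 has no tokens in this batch.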
@@ -2184,7 +2332,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
cache.used = 0;
}
-static void llama_kv_cache_seq_rm(
+static bool llama_kv_cache_seq_rm(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
@@ -2194,6 +2342,25 @@ static void llama_kv_cache_seq_rm(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ // models like Mamba can't have a state partially erased
+ if (cache.recurrent) {
+ if (seq_id >= (int64_t) cache.size) {
+ // could be fatal
+ return false;
+ }
+ if (0 <= seq_id) {
+ // partial intersection is invalid
+ if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
+ return false;
+ }
+ } else {
+ // seq_id is negative, then the range should include everything or nothing
+ if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+ return false;
+ }
+ }
+ }
+
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
if (seq_id < 0) {
@@ -2215,6 +2382,8 @@ static void llama_kv_cache_seq_rm(
// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
+
+ return true;
}
static void llama_kv_cache_seq_cp(
@@ -2226,6 +2395,29 @@ static void llama_kv_cache_seq_cp(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ if (cache.recurrent) {
+ if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
+ seq_id_src = cache.cells[seq_id_src].src;
+ GGML_ASSERT((uint32_t) seq_id_src < cache.size);
+ // intent to "copy from"
+ // supports copy chains thanks to taking the source of the source
+ cache.cells[seq_id_dst].src = seq_id_src;
+
+ // preserve the "keep or clear" status of the copied sequence
+ if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
+ cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+ } else {
+ cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
+ }
+
+ cache.do_copy = true;
+
+ cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
+ }
+ return;
+ }
+ // otherwise, this is the KV cache of a Transformer-like model
+
cache.head = 0;
for (uint32_t i = 0; i < cache.size; ++i) {
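A hedged illustration of the copy-chain handling above, using the public wrappers (a llama_context * ctx is assumed; -1/-1 selects the whole position range):

llama_kv_cache_seq_cp(ctx, 0, 1, -1, -1); // cells[1].src = 0
llama_kv_cache_seq_cp(ctx, 1, 2, -1, -1); // cells[2].src = cells[1].src = 0 (the chain collapses to the original source)
// on the next decode (or explicit KV cache update), build_s_copy() gathers row 0 of each
// per-layer state tensor into rows 1 and 2 via ggml_get_rows, then every src is reset to its own index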
@@ -2265,6 +2457,17 @@ static void llama_kv_cache_seq_add(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ if (cache.recurrent) {
+ // for Mamba-like models, only the pos needs to be shifted
+ if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+ llama_kv_cell & cell = cache.cells[seq_id];
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+ cell.pos += delta;
+ }
+ }
+ return;
+ }
+
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.has_shift = true;
@@ -2298,6 +2501,17 @@ static void llama_kv_cache_seq_div(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ if (cache.recurrent) {
+ // for Mamba-like models, only the pos needs to be changed
+ if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+ llama_kv_cell & cell = cache.cells[seq_id];
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+ cell.pos /= d;
+ }
+ }
+ return;
+ }
+
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.has_shift = true;
@@ -3117,7 +3331,7 @@ static void llm_load_hparams(
// sanity check for n_rot (optional)
{
- hparams.n_rot = hparams.n_embd / hparams.n_head;
+ hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
@@ -3130,10 +3344,10 @@ static void llm_load_hparams(
// gpt-j n_rot = rotary_dim
}
- hparams.n_embd_head_k = hparams.n_embd / hparams.n_head;
+ hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
- hparams.n_embd_head_v = hparams.n_embd / hparams.n_head;
+ hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
// arch-specific KVs
@@ -3383,6 +3597,36 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_MAMBA:
+ {
+ ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
+ ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
+ ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
+ ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 24:
+ switch (hparams.n_embd) {
+ case 768: model.type = e_model::MODEL_SMALL; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ case 48:
+ switch (hparams.n_embd) {
+ case 1024: model.type = e_model::MODEL_MEDIUM; break;
+ case 1536: model.type = e_model::MODEL_LARGE; break;
+ case 2048: model.type = e_model::MODEL_XL; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ case 64:
+ switch (hparams.n_embd) {
+ case 2560: model.type = e_model::MODEL_3B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
@@ -3702,6 +3946,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
+ LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
+ LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
+ LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
+ LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
if (ml.n_elements >= 1e12) {
@@ -4609,6 +4857,57 @@ static bool llm_load_tensors(
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff});
}
} break;
+ case LLM_ARCH_MAMBA:
+ {
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t dt_rank = hparams.ssm_dt_rank;
+ // only an expansion factor of 2 is supported for now
+ GGML_ASSERT(2 * n_embd == d_inner);
+
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ {
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+ // if output is NULL, init from the input tok embed, duplicated to allow offloading
+ if (model.output == NULL) {
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+ ml.n_created--; // artificial tensor
+ ml.size_data += ggml_nbytes(model.output);
+ }
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ // norm
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+ layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
+
+ layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
+ layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
+
+ layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
+
+ layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
+ layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
+
+ // no "weight" suffix for these
+ layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
+ layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});
+
+ // out_proj
+ layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -4834,6 +5133,8 @@ static void llm_build_kv_store(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ GGML_ASSERT(kv.size == n_ctx);
+
// compute the transposed [n_tokens, n_embd] V matrix
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
@@ -5043,6 +5344,8 @@ static struct ggml_tensor * llm_build_kqv(
cb(kq, "kq_soft_max_ext", il);
}
+ GGML_ASSERT(kv.size == n_ctx);
+
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
@@ -5190,8 +5493,8 @@ struct llm_build_context {
norm_eps (hparams.f_norm_eps),
norm_rms_eps (hparams.f_norm_rms_eps),
n_tokens (batch.n_tokens),
- n_kv (worst_case ? n_ctx : kv_self.n),
- kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
+ n_kv (worst_case ? kv_self.size : kv_self.n),
+ kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
@@ -5220,6 +5523,8 @@ struct llm_build_context {
struct ggml_cgraph * build_k_shift() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+ GGML_ASSERT(kv_self.size == n_ctx);
+
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions
@@ -5238,6 +5543,27 @@ struct llm_build_context {
return gf;
}
+ struct ggml_cgraph * build_s_copy() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ GGML_ASSERT(kv_self.recurrent);
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+ struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+
+ conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
+ ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
+
+ // TODO: name the intermediate tensors with cb()
+
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
+ }
+
+ return gf;
+ }
+
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@@ -7835,6 +8161,145 @@ struct llm_build_context {
return gf;
}
+
+ struct ggml_cgraph * build_mamba() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ const int64_t d_model = n_embd;
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ GGML_ASSERT(2 * d_model == d_inner);
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t dt_rank = hparams.ssm_dt_rank;
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ // {n_embd, n_tokens}
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+ cb(inpL, "inp_embd", -1);
+
+ struct ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
+ struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
+
+ for (int il = 0; il < n_layer; ++il) {
+ // (ab)using the KV cache to store the states
+ struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+ struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+
+ // clear states of sequences which are starting at the beginning of this batch
+ {
+ conv_states = ggml_mul(ctx0,
+ ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
+ state_mask);
+ ssm_states = ggml_mul(ctx0,
+ ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
+ state_mask);
+ }
+
+ conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
+ ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
+ struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
+ // split the above in two
+ // => {d_inner, n_tokens}
+ struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
+ struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
+
+ // conv
+ {
+ // Custom operator which is needed only to ease simultaneous sequence processing.
+ // For a single sequence, the equivalent is to concatenate the columns of conv_states and x,
+ // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
+ // then element-wise multiply that with the conv1d weight,
+ // then sum the elements of each row,
+ // (the last two steps are a dot product over rows (also doable with mul_mat))
+ // then permute away the ne[0] dimension,
+ // and then you're left with the resulting x tensor.
+ // The new conv_states is the last (d_conv - 1) columns
+ // of the last 3rd dimensional "layer" of the self-overlapping view.
+ // For simultaneous sequences, it's more complicated.
+ struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
+
+ // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
+ ggml_build_forward_expand(gf,
+ ggml_cpy(ctx0,
+ ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
+ ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
+
+ // extract x from x_conv
+ x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
+
+ // bias
+ x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
+
+ x = ggml_silu(ctx0, x);
+ }
+
+ // ssm
+ {
+ // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
+ struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
+ // split
+ struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
+ struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+ struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+
+ // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
+ dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
+ dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
+
+ // Custom operator to optimize the parallel associative scan
+ // as described in the Annex D of the Mamba paper.
+ // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
+ // because only a single tensor can be returned.
+ struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq);
+
+ // store last states (the second part of y_ssm_states)
+ ggml_build_forward_expand(gf,
+ ggml_cpy(ctx0,
+ ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
+ ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states))));
+
+ struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
+
+ // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
+ y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+ y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
+
+ // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
+ cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
+ }
+
+ // residual
+ cur = ggml_add(ctx0, cur, inpL);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ // final rmsnorm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
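The long comment inside the conv block of build_mamba() describes the single-sequence equivalent of ggml_ssm_conv. As a hedged reference, under assumed row-major layouts (not the ggml kernel or its memory layout), it boils down to a rolling depth-wise convolution:

// single-sequence sketch of the conv step; bias and SiLU are applied afterwards, as in the graph
static void ssm_conv_ref(
        float       * conv_state, // [d_inner][d_conv - 1] rolling history, updated in place
        float       * x,          // [n_tokens][d_inner] input, overwritten with the conv output
        const float * w,          // [d_inner][d_conv] depth-wise conv1d weight
        int n_tokens, int d_inner, int d_conv) {
    for (int t = 0; t < n_tokens; ++t) {
        for (int i = 0; i < d_inner; ++i) {
            float * st = &conv_state[i*(d_conv - 1)];
            const float x_ti = x[t*d_inner + i];
            // dot product of the kernel with the last d_conv inputs of channel i
            float acc = w[i*d_conv + d_conv - 1] * x_ti;
            for (int k = 0; k < d_conv - 1; ++k) {
                acc += w[i*d_conv + k] * st[k];
            }
            // slide the history: drop the oldest column, append the current input
            for (int k = 0; k + 1 < d_conv - 1; ++k) {
                st[k] = st[k + 1];
            }
            if (d_conv > 1) {
                st[d_conv - 2] = x_ti;
            }
            x[t*d_inner + i] = acc;
        }
    }
}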
@@ -7871,6 +8336,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
return result;
}
+static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
+ llama_batch dummy;
+ dummy.n_tokens = 0;
+
+ llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+ struct llm_build_context llm(lctx, dummy, cb, false);
+
+ llm.init();
+
+ struct ggml_cgraph * result = llm.build_s_copy();
+
+ llm.free();
+
+ return result;
+}
+
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
const llama_batch & batch,
@@ -7985,6 +8467,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_starcoder2();
} break;
+ case LLM_ARCH_MAMBA:
+ {
+ result = llm.build_mamba();
+ } break;
default:
GGML_ASSERT(false);
}
@@ -7995,19 +8481,29 @@ static struct ggml_cgraph * llama_build_graph(
}
static void llama_set_k_shift(llama_context & lctx) {
- const auto & cparams = lctx.cparams;
-
- const int64_t n_ctx = cparams.n_ctx;
+ const int64_t kv_size = lctx.kv_self.size;
assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
int32_t * data = (int32_t *) lctx.inp_K_shift->data;
- for (int i = 0; i < n_ctx; ++i) {
+ for (int i = 0; i < kv_size; ++i) {
data[i] = lctx.kv_self.cells[i].delta;
}
}
+static void llama_set_s_copy(llama_context & lctx) {
+ const int64_t kv_size = lctx.kv_self.size;
+
+ assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+ int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+ for (int i = 0; i < kv_size; ++i) {
+ data[i] = lctx.kv_self.cells[i].src;
+ }
+}
+
static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
//
// set input data
@@ -8044,6 +8540,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
float * data = (float *) lctx.inp_KQ_mask->data;
+ // For causal attention, use only the previous KV cells
+ // of the correct sequence for each token of the batch.
+ // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
@@ -8149,6 +8648,53 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
}
+
+ if (kv_self.recurrent) {
+ const int64_t n_kv = kv_self.n;
+
+ {
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
+ float * data = (float *) lctx.inp_s_mask->data;
+
+ // states which are not affected by the current batch are left untouched
+ for (int i = 0; i < n_kv; ++i) {
+ llama_seq_id seq_id = i + lctx.kv_self.head;
+ llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
+ bool has_self_seq = kv_cell.has_seq_id(seq_id);
+
+ data[i] = (float) has_self_seq;
+
+ // ensure current sequences will be kept
+ if (!has_self_seq && kv_cell.pos >= 0) {
+ kv_cell.seq_id.insert(seq_id);
+ }
+ }
+ }
+ // For Mamba (and other recurrent architectures),
+ // update the correct state(s)/sequence(s) for each token of the batch.
+ // Like with the KQ_mask, if a token in the batch has multiple sequences,
+ // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
+ {
+ const int64_t n_tokens = batch.n_tokens;
+
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
+ int32_t * data = (int32_t *) lctx.inp_s_seq->data;
+
+ for (int j = 0; j < n_tokens; ++j) {
+ const int32_t n_seq = batch.n_seq_id[j];
+ GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence
+
+ for (int i = 0; i < n_kv; ++i) {
+ if (i < n_seq) {
+ // for this type of model, the head is the minimum seq_id of the batch
+ data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head;
+ } else {
+ data[j*n_kv + i] = -1;
+ }
+ }
+ }
+ }
+ }
}
static void llama_graph_compute(
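A small worked example of the inp_s_seq filling above (values are illustrative): with kv_self.head == 2 and n_kv == 3, a token belonging only to sequence 3 gets the row [1, -1, -1], while a token shared by sequences 2 and 4 gets [0, 2, -1]; the trailing -1 entries are what lets ggml_ssm_conv and ggml_ssm_scan stop scanning a token's sequences early.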
@@ -8271,11 +8817,13 @@ static int llama_decode_internal(
return 1;
}
- // a heuristic, to avoid attending the full cache if it is not yet utilized
- // after enough generations, the benefit from this heuristic disappears
- // if we start defragmenting the cache, the benefit from this will be more important
- kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
- //kv_self.n = llama_kv_cache_cell_max(kv_self);
+ if (!kv_self.recurrent) {
+ // a heuristic, to avoid attending the full cache if it is not yet utilized
+ // after enough generations, the benefit from this heuristic disappears
+ // if we start defragmenting the cache, the benefit from this will be more important
+ kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+ //kv_self.n = llama_kv_cache_cell_max(kv_self);
+ }
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
@@ -8701,6 +9249,26 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
}
}
+ if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
+ llama_set_s_copy(lctx);
+
+ {
+ ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
+
+ llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+ }
+
+ {
+ auto & kv_self = lctx.kv_self;
+
+ kv_self.do_copy = false;
+
+ for (uint32_t i = 0; i < kv_self.size; ++i) {
+ kv_self.cells[i].src = i;
+ }
+ }
+ }
+
// defragment the KV cache if needed
if (lctx.kv_self.do_defrag) {
llama_kv_cache_defrag_internal(lctx);
@@ -11535,6 +12103,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
+ // do not quantize Mamba's small yet 2D weights
+ // NOTE: can't use LLM_TN here because the layer number is not known
+ quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
+ quantize &= name.find("ssm_x.weight") == std::string::npos;
+ quantize &= name.find("ssm_dt.weight") == std::string::npos;
+
enum ggml_type new_type;
void * new_data;
size_t new_size;
@@ -11985,6 +12559,7 @@ struct llama_context_params llama_context_default_params() {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512,
/*.n_batch =*/ 512,
+ /*.n_parallel =*/ 1,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
@@ -12146,6 +12721,7 @@ struct llama_context * llama_new_context_with_model(
auto & cparams = ctx->cparams;
cparams.n_batch = params.n_batch;
+ // TODO: maybe add n_parallel here too
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -12203,8 +12779,18 @@ struct llama_context * llama_new_context_with_model(
ctx->rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all;
- const ggml_type type_k = params.type_k;
- const ggml_type type_v = params.type_v;
+ uint32_t kv_size = cparams.n_ctx;
+ ggml_type type_k = params.type_k;
+ ggml_type type_v = params.type_v;
+
+ // Mamba only needs a constant number of KV cache cells per sequence
+ if (model->arch == LLM_ARCH_MAMBA) {
+ // Mamba needs at least as many KV cells as there are sequences kept at any time
+ kv_size = std::max((uint32_t) 1, params.n_parallel);
+ // it's probably best to keep as much precision as possible for the states
+ type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+ type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+ }
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
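A hedged usage sketch of what this block means for callers (a llama_model * model already loaded from a Mamba GGUF is assumed; the values are arbitrary):

struct llama_context_params cparams = llama_context_default_params();
cparams.n_ctx      = 4096; // Mamba's state size does not grow with the context length
cparams.n_parallel = 4;    // becomes kv_size: one recurrent cell per concurrent sequence
// whatever is requested, type_k/type_v end up as F32 for the conv/ssm states
struct llama_context * lctx = llama_new_context_with_model(model, cparams);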
@@ -12304,7 +12890,7 @@ struct llama_context * llama_new_context_with_model(
}
ctx->backends.push_back(ctx->backend_cpu);
- if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) {
+ if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
@@ -12338,7 +12924,7 @@ struct llama_context * llama_new_context_with_model(
// graph inputs
{
ggml_init_params init_params = {
- /* .mem_size */ ggml_tensor_overhead()*8,
+ /* .mem_size */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.recurrent)),
/* .mem_buffer */ nullptr,
/* .no_alloc */ true,
};
@@ -12347,11 +12933,16 @@ struct llama_context * llama_new_context_with_model(
ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
- ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
- ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
- ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
+ ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, kv_size, cparams.n_batch);
+ ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
+ ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
+ if (ctx->kv_self.recurrent) {
+ ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
+ ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
+ ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
+ }
ggml_set_name(ctx->inp_tokens, "inp_tokens");
ggml_set_name(ctx->inp_embd, "inp_embd");
@@ -12361,6 +12952,11 @@ struct llama_context * llama_new_context_with_model(
ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
ggml_set_name(ctx->inp_mean, "inp_mean");
ggml_set_name(ctx->inp_cls, "inp_cls");
+ if (ctx->kv_self.recurrent) {
+ ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
+ ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
+ ggml_set_name(ctx->inp_s_seq, "inp_s_seq");
+ }
ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
@@ -12447,6 +13043,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
return ctx->cparams.n_batch;
}
+uint32_t llama_n_max_seq(const struct llama_context * ctx) {
+ return ctx->kv_self.size;
+}
+
enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
return model->vocab.type;
}
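A usage sketch (the surrounding variables are hypothetical): callers such as the perplexity example clamp their sequence count to this limit, since a recurrent cache has exactly one cell per sequence.

const uint32_t n_seq_max = llama_n_max_seq(ctx); // == kv_self.size
uint32_t n_seq = 8;                              // how many sequences the caller would like to run
if (n_seq > n_seq_max) {
    n_seq = n_seq_max; // for Mamba, seq_ids must stay below the number of recurrent cells
}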
@@ -12460,6 +13060,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_MPT:
case LLM_ARCH_REFACT:
case LLM_ARCH_BLOOM:
+ case LLM_ARCH_MAMBA:
return LLAMA_ROPE_TYPE_NONE;
// use what we call a normal RoPE, operating on pairs of consecutive head values
@@ -12713,8 +13314,8 @@ void llama_kv_cache_clear(struct llama_context * ctx) {
llama_kv_cache_clear(ctx->kv_self);
}
-void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
- llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
+bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+ return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
}
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
@@ -12891,8 +13492,8 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
const size_t kv_buf_size = kv_self.total_size();
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
@@ -12913,6 +13514,17 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
data_ctx->write(tmp_buf.data(), tmp_buf.size());
+ if (kv_self.recurrent) {
+ // v is contiguous for recurrent models
+ // TODO: use other tensors for state models than k and v
+ const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
+
+ tmp_buf.resize(v_size);
+ ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
+ data_ctx->write(tmp_buf.data(), tmp_buf.size());
+ continue;
+ }
+
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
@@ -13005,8 +13617,8 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
size_t kv_buf_size;
uint32_t kv_head;
@@ -13027,6 +13639,16 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
inp += k_size;
+ if (kv_self.recurrent) {
+ // v is contiguous for recurrent models
+ // TODO: use other tensors for state models than k and v
+ const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
+
+ ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size);
+ inp += v_size;
+ continue;
+ }
+
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);