-rw-r--r--   README.md                                  |   2
-rw-r--r--   common/common.cpp                          |   1
-rwxr-xr-x   convert-hf-to-gguf.py                      | 118
-rw-r--r--   examples/batched-bench/batched-bench.cpp   |  13
-rw-r--r--   examples/batched/batched.cpp               |   3
-rw-r--r--   examples/parallel/parallel.cpp             |  20
-rw-r--r--   examples/perplexity/perplexity.cpp         |   9
-rw-r--r--   examples/server/server.cpp                 |  51
-rw-r--r--   ggml.c                                     | 379
-rw-r--r--   ggml.h                                     |  19
-rw-r--r--   gguf-py/gguf/constants.py                  |  41
-rw-r--r--   gguf-py/gguf/gguf_writer.py                |  12
-rw-r--r--   gguf-py/gguf/tensor_mapping.py             |  46
-rw-r--r--   llama.cpp                                  | 698
-rw-r--r--   llama.h                                    |   4
15 files changed, 1342 insertions(+), 74 deletions(-)
diff --git a/README.md b/README.md
index f754022d..d7dba73e 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
### Recent API changes
+- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_max_seq()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
- [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849
@@ -110,6 +111,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [InternLM2](https://huggingface.co/models?search=internlm2)
- [x] [CodeShell](https://github.com/WisdomShell/codeshell)
- [x] [Gemma](https://ai.google.dev/gemma)
+- [x] [Mamba](https://github.com/state-spaces/mamba)
**Multimodal models:**
diff --git a/common/common.cpp b/common/common.cpp
index c244db64..d7f650ef 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1288,6 +1288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
+ cparams.n_parallel = params.n_parallel;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed;
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index f6369af3..5eee3201 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1847,6 +1847,124 @@ class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2
+@Model.register("MambaForCausalLM", "MambaLMHeadModel")
+class MambaModel(Model):
+ model_arch = gguf.MODEL_ARCH.MAMBA
+
+ def set_vocab(self):
+ vocab_size = self.hparams["vocab_size"]
+ # Round vocab size to next multiple of 8
+ pad_vocab = self.hparams.get("pad_vocab_size_multiple", 8)
+ # pad using ceiling division
+ # ref: https://stackoverflow.com/a/17511341/22827863
+ vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
+ self.hparams["vocab_size"] = vocab_size
+
+ if (self.dir_model / "tokenizer.json").is_file():
+ self._set_vocab_gpt2()
+ else:
+ # Use the GPT-NeoX tokenizer when no tokenizer files are present
+ tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
+ print(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
+ neox_reader = gguf.GGUFReader(tokenizer_path, "r")
+
+ field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
+ self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]))
+ field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
+ self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
+ field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
+ self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
+ field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
+ self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
+ field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
+ self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0])
+ field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
+ self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0])
+ field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
+ self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])
+
+ def set_gguf_parameters(self):
+ d_model = self.find_hparam(["hidden_size", "d_model"])
+ d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+ d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+ d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
+ # ceiling division
+ # ref: https://stackoverflow.com/a/17511341/22827863
+ # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+ dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
+ rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+ # Fail early for models which don't have a block expansion factor of 2
+ assert d_inner == 2 * d_model
+
+ self.gguf_writer.add_name(self.dir_model.name)
+ self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
+ self.gguf_writer.add_embedding_length(d_model)
+ self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
+ self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
+ self.gguf_writer.add_block_count(self.hparams["n_layer"])
+ self.gguf_writer.add_ssm_conv_kernel(d_conv)
+ self.gguf_writer.add_ssm_inner_size(d_inner)
+ self.gguf_writer.add_ssm_state_size(d_state)
+ self.gguf_writer.add_ssm_time_step_rank(dt_rank)
+ self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+ self.gguf_writer.add_file_type(self.ftype)
+
+ def write_tensors(self):
+ block_count = self.hparams["n_layer"]
+ tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+ tok_embd = None
+ tok_embd_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD] + ".weight"
+ output_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT] + ".weight"
+
+ for name, data_torch in self.get_tensors():
+ old_dtype = data_torch.dtype
+
+ # convert any unsupported data types to float32
+ if data_torch.dtype not in (torch.float16, torch.float32):
+ data_torch = data_torch.to(torch.float32)
+
+ # map tensor names
+ new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+ if new_name is None:
+ print(f"Can not map tensor {name!r}")
+ sys.exit()
+
+ if name.endswith(".A_log"):
+ print("A_log --> A ==> " + new_name)
+ data_torch = -torch.exp(data_torch)
+
+ # assuming token_embd.weight is seen before output.weight
+ if tok_embd is not None and new_name == output_name:
+ if torch.equal(tok_embd, data_torch):
+ print(f"{output_name} is equivalent to {tok_embd_name}, omitting")
+ continue
+ if new_name == tok_embd_name:
+ tok_embd = data_torch
+
+ data = data_torch.squeeze().numpy()
+
+ n_dims = len(data.shape)
+ data_dtype = data.dtype
+
+ # if f32 desired, convert any float16 to float32
+ if self.ftype == 0 and data_dtype == np.float16:
+ data = data.astype(np.float32)
+
+ # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+ if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+ data = data.astype(np.float32)
+
+ # if f16 desired, convert big float32 2-dim weight tensors to float16
+ if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
+ data = data.astype(np.float16)
+
+ print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+ self.gguf_writer.add_tensor(new_name, data)
+
+
###### CONVERSION LOGIC ######
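Note on the vocab padding above: `-(vocab_size // -pad_vocab) * pad_vocab` is Python's ceiling-division idiom, i.e. it rounds the vocab size up to the next multiple of `pad_vocab_size_multiple`. A minimal C++ equivalent of that arithmetic (illustrative only, not part of the commit; the function name is made up):

    #include <cstdint>
    #include <cstdio>

    // Round `vocab_size` up to the next multiple of `pad`, matching the
    // Python idiom `-(vocab_size // -pad) * pad` used in MambaModel.set_vocab.
    static uint32_t pad_vocab_size(uint32_t vocab_size, uint32_t pad) {
        return ((vocab_size + pad - 1) / pad) * pad;
    }

    int main() {
        // e.g. a 50277-token vocab padded to a multiple of 8 becomes 50280
        printf("%u\n", pad_vocab_size(50277, 8));
        return 0;
    }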
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 19aff18a..dff6c68e 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -105,6 +105,9 @@ int main(int argc, char ** argv) {
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+ // ensure enough sequences are available
+ ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end());
+
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
@@ -174,10 +177,10 @@ int main(int argc, char ** argv) {
llama_batch_clear(batch);
- const int n_tokens = is_pp_shared ? pp : pl*pp;
-
- for (int i = 0; i < n_tokens; ++i) {
- llama_batch_add(batch, 0, i, { 0 }, false);
+ for (int i = 0; i < pp; ++i) {
+ for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
+ llama_batch_add(batch, 0, i, { j }, false);
+ }
}
batch.logits[batch.n_tokens - 1] = true;
@@ -192,7 +195,7 @@ int main(int argc, char ** argv) {
if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
- llama_kv_cache_seq_cp(ctx, 0, i, 0, pp);
+ llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
}
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 9be7eb56..dde4d5a0 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -80,6 +80,7 @@ int main(int argc, char ** argv) {
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_len, n_parallel);
+ ctx_params.n_parallel = n_parallel;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
@@ -132,7 +133,7 @@ int main(int argc, char ** argv) {
// assign the system KV cache to all parallel sequences
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (int32_t i = 1; i < n_parallel; ++i) {
- llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
+ llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
if (n_parallel > 1) {
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 7d11fcd5..a2ef0fb0 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -107,6 +107,9 @@ int main(int argc, char ** argv) {
// number of simultaneous "clients" to simulate
const int32_t n_clients = params.n_parallel;
+ // dedicate one sequence to the system prompt
+ params.n_parallel += 1;
+
// requests to simulate
const int32_t n_seq = params.n_sequences;
@@ -196,8 +199,8 @@ int main(int argc, char ** argv) {
}
// assign the system KV cache to all parallel sequences
- for (int32_t i = 1; i < n_clients; ++i) {
- llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
+ for (int32_t i = 1; i <= n_clients; ++i) {
+ llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
LOG_TEE("\n");
@@ -221,15 +224,17 @@ int main(int argc, char ** argv) {
client.i_batch = batch.n_tokens;
- llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
+ llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
client.n_decoded += 1;
}
if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache
- for (int i = 0; i < n_clients; ++i) {
- llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
+ for (int i = 1; i <= n_clients; ++i) {
+ llama_kv_cache_seq_rm(ctx, i, -1, -1);
+ // but keep the system prompt
+ llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
LOG_TEE("%s: clearing the KV cache\n", __func__);
@@ -255,7 +260,7 @@ int main(int argc, char ** argv) {
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
- llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
+ llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
}
// extract the logits only for the last token
@@ -366,7 +371,8 @@ int main(int argc, char ** argv) {
}
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
- llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
+ llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
+ llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
const auto t_main_end = ggml_time_us();
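The pattern above (sequence 0 holds the shared system prompt, client i uses seq_id i + 1, and p0 = p1 = -1 selects a whole sequence) is what makes this example work with recurrent models, which can only copy or clear entire sequences. A condensed sketch of the same pattern, assuming llama.h from this repo and a system prompt already decoded into sequence 0; the function names are illustrative:

    #include "llama.h"

    // Share the already-decoded system prompt (sequence 0) with every client;
    // clients live at seq_id 1..n_clients so they never collide with it.
    static void attach_system_prompt(llama_context * ctx, int32_t n_clients) {
        for (int32_t i = 1; i <= n_clients; ++i) {
            llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); // -1, -1 == the whole sequence
        }
    }

    // Reset one client while keeping the shared system prompt attached.
    static void reset_client(llama_context * ctx, int32_t client_id) {
        llama_kv_cache_seq_rm(ctx, client_id + 1, -1, -1);
        llama_kv_cache_seq_cp(ctx, 0, client_id + 1, -1, -1);
    }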
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 9ec98938..52789ee6 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -809,7 +809,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const int n_batch = params.n_batch;
const int max_tasks_per_batch = 32;
- const int max_seq = 4*max_tasks_per_batch;
+ const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
@@ -1086,7 +1086,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
const int n_batch = params.n_batch;
const int max_tasks_per_batch = 128;
- const int max_seq = 2*max_tasks_per_batch;
+ const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
@@ -1438,7 +1438,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
const int n_batch = params.n_batch;
const int max_tasks_per_batch = 32;
- const int max_seq = 4*max_tasks_per_batch;
+ const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
@@ -1815,6 +1815,9 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;
+ // ensure there's at least enough seq_ids for HellaSwag
+ params.n_parallel = std::max(4, params.n_parallel);
+
// load the model and apply lora adapter, if any
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 109ff717..59a59d56 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -659,7 +659,11 @@ struct server_context {
bool load_model(const gpt_params & params_) {
params = params_;
+ // dedicate one sequence to the system prompt
+ params.n_parallel += 1;
+
std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ params.n_parallel -= 1; // but be sneaky about it
if (model == nullptr) {
LOG_ERROR("unable to load model", {{"model", params.model}});
return false;
@@ -1018,8 +1022,8 @@ struct server_context {
}
// assign the system KV cache to all parallel sequences
- for (int32_t i = 1; i < params.n_parallel; ++i) {
- llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
+ for (int32_t i = 1; i <= params.n_parallel; ++i) {
+ llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
}
@@ -1306,7 +1310,7 @@ struct server_context {
const int n_embd = llama_n_embd(model);
for (int i = 0; i < batch.n_tokens; ++i) {
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
continue;
}
@@ -1633,8 +1637,8 @@ struct server_context {
{"n_cache_tokens", slot.cache_tokens.size()}
});
- llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
- llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+ llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
+ llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -1666,7 +1670,7 @@ struct server_context {
// TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow
- llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
+ llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
slot.n_past += 1;
@@ -1804,9 +1808,6 @@ struct server_context {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
- // remove the non-common part from the cache
- slot.cache_tokens.resize(slot.n_past);
-
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
@@ -1837,8 +1838,28 @@ struct server_context {
}
}
- const int p0 = (int) system_tokens.size() + slot.n_past;
- llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
+ // keep only the common part
+ int p0 = (int) system_tokens.size() + slot.n_past;
+ if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+ // could not partially delete (likely using a non-Transformer model)
+ llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+
+ p0 = (int) system_tokens.size();
+ if (p0 != 0) {
+ // copy over the system prompt when there is one
+ llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
+ }
+
+ // there is no common part left (except for the system prompt)
+ slot.n_past = 0;
+ slot.n_past_se = 0;
+ slot.ga_i = 0;
+ // TODO: is the system prompt ever in the sampling context?
+ llama_sampling_reset(slot.ctx_sampling);
+ }
+
+ // remove the non-common part from the cache
+ slot.cache_tokens.resize(slot.n_past);
LOG_INFO("kv cache rm [p0, end)", {
{ "id_slot", slot.id },
@@ -1863,7 +1884,7 @@ struct server_context {
}
}
- llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
+ llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -1937,9 +1958,9 @@ struct server_context {
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
- llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
- llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
- llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
+ llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
+ llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
+ llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
slot.n_past_se -= bd;
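The server's use of the new bool return is the interesting part: when llama_kv_cache_seq_rm() refuses a partial erase (as recurrent models do), the slot falls back to clearing its whole sequence, re-attaching the system prompt, and reprocessing from n_past = 0. A condensed sketch of that fallback, assuming llama.h; `keep_common_prefix` is an illustrative name, and `slot_seq` / `n_system` stand for the server's slot.id + 1 and system_tokens.size():

    #include "llama.h"

    // Try to keep only the first `n_past` tokens of a slot's sequence.
    // Returns the n_past the caller should continue decoding from.
    static int keep_common_prefix(llama_context * ctx, llama_seq_id slot_seq, int n_system, int n_past) {
        const llama_pos p0 = n_system + n_past;
        if (!llama_kv_cache_seq_rm(ctx, slot_seq, p0, -1)) {
            // partial removal rejected (e.g. a recurrent model such as Mamba):
            // drop the whole sequence and re-attach the shared system prompt
            llama_kv_cache_seq_rm(ctx, slot_seq, -1, -1);
            if (n_system > 0) {
                llama_kv_cache_seq_cp(ctx, 0, slot_seq, -1, -1);
            }
            // (the real server also resets n_past_se, ga_i and the sampling state here)
            return 0; // nothing cached beyond the system prompt anymore
        }
        return n_past;
    }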
diff --git a/ggml.c b/ggml.c
index 92b17ee6..6eff98ab 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1841,6 +1841,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"FLASH_ATTN",
"FLASH_FF",
"FLASH_ATTN_BACK",
+ "SSM_CONV",
+ "SSM_SCAN",
"WIN_PART",
"WIN_UNPART",
"GET_REL_POS",
@@ -1863,7 +1865,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1929,6 +1931,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"flash_attn(x)",
"flash_ff(x)",
"flash_attn_back(x)",
+ "ssm_conv(x)",
+ "ssm_scan(x)",
"win_part(x)",
"win_unpart(x)",
"get_rel_pos(x)",
@@ -1951,7 +1955,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -6154,6 +6158,108 @@ struct ggml_tensor * ggml_flash_attn_back(
return result;
}
+// ggml_ssm_conv
+
+struct ggml_tensor * ggml_ssm_conv(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * c,
+ struct ggml_tensor * sq) {
+ GGML_ASSERT(ggml_is_3d(s));
+ GGML_ASSERT(ggml_is_matrix(x));
+ GGML_ASSERT(ggml_is_matrix(c));
+ GGML_ASSERT(ggml_is_matrix(sq));
+ GGML_ASSERT(sq->type == GGML_TYPE_I32);
+
+ const int64_t d_conv = c->ne[0];
+ const int64_t d_inner = c->ne[1];
+ const int64_t n_tokens = x->ne[1];
+ const int64_t n_kv = s->ne[2];
+
+ GGML_ASSERT( s->ne[0] == d_conv - 1);
+ GGML_ASSERT( s->ne[1] == d_inner);
+ GGML_ASSERT( x->ne[0] == d_inner);
+ GGML_ASSERT(sq->ne[0] == n_kv);
+ GGML_ASSERT(sq->ne[1] == n_tokens);
+
+ bool is_node = false;
+
+ if (s->grad || x->grad || c->grad || sq->grad) {
+ GGML_ASSERT(false); // TODO: implement
+ is_node = true;
+ }
+
+ // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+
+ result->op = GGML_OP_SSM_CONV;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = s;
+ result->src[1] = x;
+ result->src[2] = c;
+ result->src[3] = sq;
+
+ return result;
+}
+
+// ggml_ssm_scan
+
+struct ggml_tensor * ggml_ssm_scan(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * dt,
+ struct ggml_tensor * A,
+ struct ggml_tensor * B,
+ struct ggml_tensor * C,
+ struct ggml_tensor * sq) {
+ GGML_ASSERT(ggml_is_contiguous(s));
+ GGML_ASSERT(ggml_is_contiguous(x));
+ GGML_ASSERT(ggml_is_contiguous(dt));
+ GGML_ASSERT(ggml_is_contiguous(A));
+ GGML_ASSERT(sq->type == GGML_TYPE_I32);
+ GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
+ GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
+ GGML_ASSERT(ggml_are_same_shape(x, dt));
+
+ {
+ const int64_t d_state = s->ne[0];
+ const int64_t d_inner = s->ne[1];
+ const int64_t n_tokens = x->ne[1];
+
+ GGML_ASSERT(x->ne[0] == d_inner);
+ GGML_ASSERT(A->ne[0] == d_state);
+ GGML_ASSERT(A->ne[1] == d_inner);
+ GGML_ASSERT(B->ne[0] == d_state);
+ GGML_ASSERT(B->ne[1] == n_tokens);
+ GGML_ASSERT(C->ne[0] == d_state);
+ GGML_ASSERT(C->ne[1] == n_tokens);
+ }
+
+ bool is_node = false;
+
+ if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
+ GGML_ASSERT(false); // TODO: implement
+ is_node = true;
+ }
+
+ // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+
+ result->op = GGML_OP_SSM_SCAN;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = s;
+ result->src[1] = x;
+ result->src[2] = dt;
+ result->src[3] = A;
+ result->src[4] = B;
+ result->src[5] = C;
+ result->src[6] = sq;
+
+ return result;
+}
+
// ggml_win_part
struct ggml_tensor * ggml_win_part(
@@ -14771,6 +14877,257 @@ static void ggml_compute_forward_flash_attn_back(
}
}
+// ggml_compute_forward_ssm_conv
+
+static void ggml_compute_forward_ssm_conv_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+ return;
+ }
+
+ const struct ggml_tensor * src0 = dst->src[0]; // conv_state
+ const struct ggml_tensor * src1 = dst->src[1]; // x
+ const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
+ const struct ggml_tensor * src3 = dst->src[3]; // state_seq
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src2->ne[0]; // d_conv
+ const int nr = src0->ne[1]; // d_inner
+ const int n_t = src1->ne[1]; // n_tokens
+ const int n_kv = src0->ne[2]; // max number of sequences in the batch
+
+ GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+ GGML_ASSERT(src2->nb[0] == sizeof(float));
+ GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
+ GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+ // for use with the destination state offset between sequences
+ GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+ const int ir = ir1 - ir0;
+
+ if (n_kv > 1) {
+ // multiple sequences means it's hard to know when it's the first time a state is read,
+ // so copy them all over to the destination, just to be sure.
+ for (int i3 = 0; i3 < n_kv; ++i3) {
+ float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
+ float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
+ // can't use memcpy because of d_conv vs d_conv - 1
+ for (int i1 = 0; i1 < ir; ++i1) {
+ for (int i0 = 0; i0 < nc - 1; ++i0) {
+ // copy s0 to last (d_conv - 1) columns of s
+ s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
+ }
+ }
+ }
+ }
+
+ for (int i2 = 0; i2 < n_t; ++i2) {
+ int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
+ float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
+ float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
+ float * s0; // {d_conv - 1, d_inner, n_kv}
+ float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
+ int ne0s0;
+
+ GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+
+ // avoid needing to copy the state for the first token
+ if (i2 == 0) {
+ s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
+ ne0s0 = src0->ne[0];
+ } else {
+ // the source is the last (d_conv - 1) columns of the destination
+ s0 = s + 1;
+ ne0s0 = nc;
+ }
+
+ // d_inner
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // shift state left
+ for (int i0 = 0; i0 < nc - 1; ++i0) {
+ s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
+ }
+ // insert x on the last column
+ s[(nc - 1) + i1*nc] = x0[i1];
+ }
+
+ // handle copies when there are multiple output states
+ for (int i3 = 1; i3 < n_kv; ++i3) {
+ int32_t seq = sq[i3];
+ if (0 <= seq && seq < n_kv) {
+ float * s1 = s + (seq - sq[0])*nc*nr;
+ memcpy(s1, s, nc*ir*sizeof(float));
+ } else {
+ // stop at negative or too big seq_ids
+ break;
+ }
+ }
+
+ // it seems a little faster when this is separate from the state shift
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // rowwise dot product
+ float sumf = 0.0f;
+ for (int i0 = 0; i0 < nc; ++i0) {
+ int i = i0 + i1*nc;
+ sumf += s[i] * c[i];
+ }
+ x[i1] = sumf;
+ }
+ }
+}
+
+static void ggml_compute_forward_ssm_conv(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ switch (dst->src[0]->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_ssm_conv_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_ssm_scan
+
+static void ggml_compute_forward_ssm_scan_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+ return;
+ }
+
+ const struct ggml_tensor * src0 = dst->src[0]; // s
+ const struct ggml_tensor * src1 = dst->src[1]; // x
+ const struct ggml_tensor * src2 = dst->src[2]; // dt
+ const struct ggml_tensor * src3 = dst->src[3]; // A
+ const struct ggml_tensor * src4 = dst->src[4]; // B
+ const struct ggml_tensor * src5 = dst->src[5]; // C
+ const struct ggml_tensor * src6 = dst->src[6]; // sq
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // d_inner
+ const int64_t n_t = src1->ne[1]; // number of tokens in the batch
+ const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+
+ GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+ GGML_ASSERT(src2->nb[0] == sizeof(float));
+ GGML_ASSERT(src3->nb[0] == sizeof(float));
+ GGML_ASSERT(src4->nb[0] == sizeof(float));
+ GGML_ASSERT(src5->nb[0] == sizeof(float));
+ // required for the dot product between s and C, and when copying the states
+ GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+ // required for per-sequence offsets for states
+ GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
+ // required to get correct offset for state destination (i.e. src1->nb[2])
+ GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+ const int ir = ir1 - ir0;
+
+ if (n_kv > 1) {
+ // it's hard to know if the source states have already been copied
+ // when there are multiple, so copy them already.
+ for (int i3 = 0; i3 < n_kv; ++i3) {
+ float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
+ memcpy(s, s0, nc*ir*sizeof(float));
+ }
+ }
+
+ for (int i2 = 0; i2 < n_t; ++i2) {
+ int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
+ float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
+ float * s0;
+ float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
+ float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+ float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
+ float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
+
+ GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+
+ // avoid needing to copy the state for the first token
+ if (i2 == 0) {
+ s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
+ } else {
+ // otherwise the source is the same as the destination
+ s0 = s;
+ }
+
+ // d_inner
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+ float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+ float x_dt = x[i1] * dt_soft_plus;
+ float sumf = 0.0f;
+ // d_state
+ for (int i0 = 0; i0 < nc; ++i0) {
+ int i = i0 + i1*nc;
+ // state = prev_state * dA + dB * x
+ float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[i0];
+ s[i] = state;
+ }
+ y[i1] = sumf;
+ }
+
+ // handle copies when there are multiple output states
+ for (int i3 = 1; i3 < n_kv; ++i3) {
+ int32_t seq = sq[i3];
+ if (0 <= seq && seq < n_kv) {
+ float * s1 = s + (seq - sq[0])*nc*nr;
+ memcpy(s1, s, nc*ir*sizeof(float));
+ } else {
+ // stop at negative or too big seq_ids
+ break;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_ssm_scan(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ switch (dst->src[0]->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_ssm_scan_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
// ggml_compute_forward_win_part
static void ggml_compute_forward_win_part_f32(
@@ -15830,6 +16187,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
bool masked = t != 0;
ggml_compute_forward_flash_attn_back(params, masked, tensor);
} break;
+ case GGML_OP_SSM_CONV:
+ {
+ ggml_compute_forward_ssm_conv(params, tensor);
+ } break;
+ case GGML_OP_SSM_SCAN:
+ {
+ ggml_compute_forward_ssm_scan(params, tensor);
+ } break;
case GGML_OP_WIN_PART:
{
ggml_compute_forward_win_part(params, tensor);
@@ -16884,6 +17249,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // not supported
} break;
+ case GGML_OP_SSM_CONV:
+ case GGML_OP_SSM_SCAN:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_UNARY:
@@ -17590,6 +17960,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
{
n_tasks = n_threads;
} break;
+ case GGML_OP_SSM_CONV:
+ case GGML_OP_SSM_SCAN:
+ {
+ n_tasks = n_threads;
+ } break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
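To make the math in ggml_compute_forward_ssm_scan_f32() above easier to follow: for each inner channel, dt goes through softplus, the state is updated as state = state * exp(dt * A) + B * (x * dt), and the output y is the dot product of the updated state with C. A standalone sketch of one row of that recurrence (plain C++, no ggml dependencies, illustrative only; `ssm_scan_row` is a made-up name):

    #include <cmath>
    #include <vector>

    // One token, one inner channel (row) of the Mamba selective scan, as in
    // ggml_compute_forward_ssm_scan_f32 above. `state`, `A`, `B`, `C` all have
    // d_state elements; `state` is updated in place and the scalar y is returned.
    static float ssm_scan_row(std::vector<float> & state,
                              const std::vector<float> & A,
                              const std::vector<float> & B,
                              const std::vector<float> & C,
                              float x, float dt) {
        // softplus, with the same large-value shortcut used in the kernel
        const float dt_soft_plus = dt <= 20.0f ? std::log1p(std::exp(dt)) : dt;
        const float x_dt = x * dt_soft_plus;

        float y = 0.0f;
        for (size_t i = 0; i < state.size(); ++i) {
            // state = prev_state * dA + dB * x
            state[i] = state[i] * std::exp(dt_soft_plus * A[i]) + B[i] * x_dt;
            // y = rowwise_dotprod(state, C)
            y += state[i] * C[i];
        }
        return y;
    }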
diff --git a/ggml.h b/ggml.h
index 0ea4f884..a13b0cec 100644
--- a/ggml.h
+++ b/ggml.h
@@ -472,6 +472,8 @@ extern "C" {
GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF,
GGML_OP_FLASH_ATTN_BACK,
+ GGML_OP_SSM_CONV,
+ GGML_OP_SSM_SCAN,
GGML_OP_WIN_PART,
GGML_OP_WIN_UNPART,
GGML_OP_GET_REL_POS,
@@ -1728,6 +1730,23 @@ extern "C" {
struct ggml_tensor * c0,
struct ggml_tensor * c1);
+ GGML_API struct ggml_tensor * ggml_ssm_conv(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * c,
+ struct ggml_tensor * sq);
+
+ GGML_API struct ggml_tensor * ggml_ssm_scan(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * dt,
+ struct ggml_tensor * A,
+ struct ggml_tensor * B,
+ struct ggml_tensor * C,
+ struct ggml_tensor * sq);
+
// partition into non-overlapping windows with padding if needed
// example:
// a: 768 64 64 1
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index a6213981..b23badb1 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -61,6 +61,12 @@ class Keys:
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
+ class SSM:
+ CONV_KERNEL = "{arch}.ssm.conv_kernel"
+ INNER_SIZE = "{arch}.ssm.inner_size"
+ STATE_SIZE = "{arch}.ssm.state_size"
+ TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
+
class Tokenizer:
MODEL = "tokenizer.ggml.model"
LIST = "tokenizer.ggml.tokens"
@@ -113,6 +119,7 @@ class MODEL_ARCH(IntEnum):
MINICPM = auto()
GEMMA = auto()
STARCODER2 = auto()
+ MAMBA = auto()
class MODEL_TENSOR(IntEnum):
@@ -144,6 +151,13 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
LAYER_OUT_NORM = auto()
+ SSM_IN = auto()
+ SSM_CONV1D = auto()
+ SSM_X = auto()
+ SSM_DT = auto()
+ SSM_A = auto()
+ SSM_D = auto()
+ SSM_OUT = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -171,6 +185,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.MINICPM: "minicpm",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.STARCODER2: "starcoder2",
+ MODEL_ARCH.MAMBA: "mamba",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -202,6 +217,13 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
+ MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
+ MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
+ MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
+ MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
+ MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
+ MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
+ MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -543,6 +565,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
+ MODEL_ARCH.MAMBA: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.SSM_IN,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_X,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_D,
+ MODEL_TENSOR.SSM_OUT,
+ ],
# TODO
}
@@ -734,6 +769,12 @@ KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED
+# SSM
+KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
+KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
+KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
+KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
+
# tokenization
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 80116083..e49c5db6 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -382,6 +382,18 @@ class GGUFWriter:
def add_rope_scaling_finetuned(self, value: bool) -> None:
self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
+ def add_ssm_conv_kernel(self, value: int) -> None:
+ self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
+
+ def add_ssm_inner_size(self, value: int) -> None:
+ self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value)
+
+ def add_ssm_state_size(self, value: int) -> None:
+ self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value)
+
+ def add_ssm_time_step_rank(self, value: int) -> None:
+ self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
+
def add_tokenizer_model(self, model: str) -> None:
self.add_string(Keys.Tokenizer.MODEL, model)
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index db2ec970..ed89955d 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -20,6 +20,9 @@ class TensorNameMap:
"wte", # gpt2
"transformer.embd.wte", # phi2
"model.tok_embeddings", # internlm2
+ "model.embedding", # mamba-qbert
+ "backbone.embedding", # mamba
+ "backbone.embeddings", # mamba-hf
),
# Token type embeddings
@@ -44,7 +47,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
- "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen
+ "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
@@ -61,6 +64,8 @@ class TensorNameMap:
"language_model.encoder.final_layernorm", # persimmon
"model.final_layernorm", # persimmon
"lm_head.ln", # phi2
+ "model.norm_f", # mamba-qbert
+ "backbone.norm_f", # mamba
),
# Rope frequencies
@@ -86,6 +91,8 @@ class TensorNameMap:
"transformer.h.{bid}.ln", # phi2
"model.layers.layers.{bid}.norm", # plamo
"model.layers.{bid}.attention_norm", # internlm2
+ "model.layers.{bid}.norm", # mamba-qbert
+ "backbone.layers.{bid}.norm", # mamba
),
# Attention norm 2
@@ -282,7 +289,42 @@ class TensorNameMap:
MODEL_TENSOR.LAYER_OUT_NORM: (
"encoder.layer.{bid}.output.LayerNorm", # bert
"encoder.layers.{bid}.norm2", # nomic-bert
- )
+ ),
+
+ MODEL_TENSOR.SSM_IN: (
+ "model.layers.{bid}.in_proj",
+ "backbone.layers.{bid}.mixer.in_proj",
+ ),
+
+ MODEL_TENSOR.SSM_CONV1D: (
+ "model.layers.{bid}.conv1d",
+ "backbone.layers.{bid}.mixer.conv1d",
+ ),
+
+ MODEL_TENSOR.SSM_X: (
+ "model.layers.{bid}.x_proj",
+ "backbone.layers.{bid}.mixer.x_proj",
+ ),
+
+ MODEL_TENSOR.SSM_DT: (
+ "model.layers.{bid}.dt_proj",
+ "backbone.layers.{bid}.mixer.dt_proj",
+ ),
+
+ MODEL_TENSOR.SSM_A: (
+ "model.layers.{bid}.A_log",
+ "backbone.layers.{bid}.mixer.A_log",
+ ),
+
+ MODEL_TENSOR.SSM_D: (
+ "model.layers.{bid}.D",
+ "backbone.layers.{bid}.mixer.D",
+ ),
+
+ MODEL_TENSOR.SSM_OUT: (
+ "model.layers.{bid}.out_proj",
+ "backbone.layers.{bid}.mixer.out_proj",
+ ),
}
mapping: dict[str, tuple[MODEL_TENSOR, str]]
diff --git a/llama.cpp b/llama.cpp
index 4a20b792..8c147a42 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -213,6 +213,7 @@ enum llm_arch {
LLM_ARCH_MINICPM,
LLM_ARCH_GEMMA,
LLM_ARCH_STARCODER2,
+ LLM_ARCH_MAMBA,
LLM_ARCH_UNKNOWN,
};
@@ -241,6 +242,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM, "minicpm" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
+ { LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -284,6 +286,11 @@ enum llm_kv {
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
LLM_KV_ROPE_SCALING_FINETUNED,
+ LLM_KV_SSM_INNER_SIZE,
+ LLM_KV_SSM_CONV_KERNEL,
+ LLM_KV_SSM_STATE_SIZE,
+ LLM_KV_SSM_TIME_STEP_RANK,
+
LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_LIST,
LLM_KV_TOKENIZER_TOKEN_TYPE,
@@ -342,6 +349,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
+ { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" },
+ { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
+ { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
+ { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
+
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
{ LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
@@ -399,6 +411,13 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM,
+ LLM_TENSOR_SSM_IN,
+ LLM_TENSOR_SSM_CONV1D,
+ LLM_TENSOR_SSM_X,
+ LLM_TENSOR_SSM_DT,
+ LLM_TENSOR_SSM_A,
+ LLM_TENSOR_SSM_D,
+ LLM_TENSOR_SSM_OUT,
};
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -802,6 +821,22 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
},
},
{
+ LLM_ARCH_MAMBA,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+ { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+ },
+ },
+ {
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -1613,6 +1648,12 @@ struct llama_hparams {
float rope_freq_scale_train;
uint32_t n_yarn_orig_ctx;
+ // for State Space Models
+ uint32_t ssm_d_conv = 0;
+ uint32_t ssm_d_inner = 0;
+ uint32_t ssm_d_state = 0;
+ uint32_t ssm_dt_rank = 0;
+
float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
@@ -1641,6 +1682,11 @@ struct llama_hparams {
if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
+ if (this->ssm_d_conv != other.ssm_d_conv) return true;
+ if (this->ssm_d_inner != other.ssm_d_inner) return true;
+ if (this->ssm_d_state != other.ssm_d_state) return true;
+ if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
+
const float EPSILON = 1e-9f;
if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
@@ -1652,6 +1698,9 @@ struct llama_hparams {
}
uint32_t n_gqa() const {
+ if (n_head_kv == 0) {
+ return 0;
+ }
return n_head/n_head_kv;
}
@@ -1662,6 +1711,18 @@ struct llama_hparams {
uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads
return n_embd_head_v * n_head_kv;
}
+
+ uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
+ // corresponds to Mamba's conv_states size
+ // TODO: maybe support other convolution strides than 1
+ // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+ return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+ }
+
+ uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
+ // corresponds to Mamba's ssm_states size
+ return ssm_d_state * ssm_d_inner;
+ }
};
struct llama_cparams {
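These two helpers are what turn the KV cache into recurrent-state storage for Mamba: per layer and per sequence, k_l holds (d_conv - 1) * d_inner floats of convolution state and v_l holds d_state * d_inner floats of SSM state, instead of per-token keys and values. A small worked example of those sizes (illustrative only, using the converter defaults d_conv = 4, d_state = 16, the enforced d_inner = 2 * n_embd, and d_model = 2560 from the MODEL_3B case below):

    #include <cstdint>
    #include <cstdio>

    // Per-layer, per-sequence state sizes, mirroring llama_hparams::n_embd_k_s()
    // and n_embd_v_s() above, for an assumed d_model = 2560 model.
    int main() {
        const uint32_t d_model = 2560;
        const uint32_t d_inner = 2 * d_model;               // 5120
        const uint32_t d_conv  = 4;
        const uint32_t d_state = 16;

        const uint32_t n_embd_k_s = (d_conv - 1) * d_inner; // conv state: 15360 floats
        const uint32_t n_embd_v_s =  d_state     * d_inner; // ssm  state: 81920 floats

        printf("conv state per layer: %u floats, ssm state per layer: %u floats\n",
               n_embd_k_s, n_embd_v_s);
        return 0;
    }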
@@ -1739,11 +1800,27 @@ struct llama_layer {
struct ggml_tensor * ffn_down_b; // b2
struct ggml_tensor * ffn_up_b; // b3
struct ggml_tensor * ffn_act;
+
+ // mamba proj
+ struct ggml_tensor * ssm_in;
+ struct ggml_tensor * ssm_x;
+ struct ggml_tensor * ssm_dt;
+ struct ggml_tensor * ssm_out;
+
+ // mamba
+ struct ggml_tensor * ssm_conv1d;
+ struct ggml_tensor * ssm_a;
+ struct ggml_tensor * ssm_d;
+
+ // mamba bias
+ struct ggml_tensor * ssm_conv1d_b;
+ struct ggml_tensor * ssm_dt_b;
};
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
+ int32_t src = 0; // used by recurrent state models to copy states
std::set<llama_seq_id> seq_id;
@@ -1764,6 +1841,9 @@ struct llama_kv_cell {
struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
+ bool do_copy = false;
+ // with recurrent state models, a cell can hold the state for more than one past token
+ bool recurrent = false;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it
@@ -2003,11 +2083,14 @@ struct llama_context {
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
- struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
- struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx]
- struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
+ struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
+ struct ggml_tensor * inp_KQ_pos; // F32 [kv_size]
+ struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
+ struct ggml_tensor * inp_s_copy; // I32 [kv_size]
+ struct ggml_tensor * inp_s_mask; // F32 [kv_size]
+ struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
#ifdef GGML_USE_MPI
ggml_mpi_context * ctx_mpi = NULL;
@@ -2023,25 +2106,42 @@ static bool llama_kv_cache_init(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
- uint32_t n_ctx,
+ uint32_t kv_size,
bool offload) {
const struct llama_hparams & hparams = model.hparams;
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
const int64_t n_layer = hparams.n_layer;
cache.has_shift = false;
+ // TODO: find a nicer way to add other recurrent model architectures
+ cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+
+ // TODO: support mixed recurrent Transformer architectures
+ // NOTE: (!a || b) is a logical implication (a -> b)
+ GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s());
+ GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s());
+ GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa());
+ GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa());
+
cache.head = 0;
- cache.size = n_ctx;
+ cache.size = kv_size;
cache.used = 0;
cache.type_k = type_k;
cache.type_v = type_v;
cache.cells.clear();
- cache.cells.resize(n_ctx);
+ cache.cells.resize(kv_size);
+
+ if (cache.recurrent) {
+ // init state copy sources
+ for (uint32_t i = 0; i < cache.size; ++i) {
+ cache.cells[i].src = i;
+ }
+ }
#ifdef GGML_USE_CLBLAST
offload = false;
@@ -2080,8 +2180,8 @@ static bool llama_kv_cache_init(
for (int i = 0; i < (int) n_layer; i++) {
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*n_ctx);
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*n_ctx);
+ ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+ ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
@@ -2115,6 +2215,54 @@ static bool llama_kv_cache_find_slot(
const uint32_t n_ctx = cache.size;
const uint32_t n_tokens = batch.n_tokens;
+ if (cache.recurrent) {
+ // For recurrent state architectures (like Mamba),
+ // each KV cache cell can store the state for a whole sequence.
+
+ llama_seq_id min = cache.size - 1;
+ llama_seq_id max = 0;
+
+ for (uint32_t i = 0; i < n_tokens; ++i) {
+ for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
+ llama_seq_id seq_id = batch.seq_id[i][j];
+ // make sure it's a valid seq_id
+ if ((uint32_t) seq_id < cache.size) {
+ if (seq_id > max) {
+ max = seq_id;
+ }
+ if (seq_id < min) {
+ min = seq_id;
+ }
+ // Assuming the tokens are in-order
+ if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
+ // What should happen when the pos backtracks or skips a value?
+ // Clearing the state mid-batch would require special-casing which isn't done.
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
+ __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
+ }
+ if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
+ cache.used += 1;
+ }
+ cache.cells[seq_id].pos = batch.pos[i];
+ // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
+ } else {
+ // too big seq_id
+ // TODO: would it be possible to resize the KV cache size instead?
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d. Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
+ return false;
+ }
+ }
+ }
+
+ // allow getting the range of used cells, from head to head + n
+ cache.head = min;
+ cache.n = max - min + 1;
+
+ // sanity check
+ return max >= min;
+ }
+ // otherwise, one cell per token.
+
if (n_tokens > n_ctx) {
LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
return false;
@@ -2184,7 +2332,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
cache.used = 0;
}
-static void llama_kv_cache_seq_rm(
+static bool llama_kv_cache_seq_rm(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
@@ -2194,6 +2342,25 @@ static void llama_kv_cache_seq_rm(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ // models like Mamba can't have a state partially erased
+ if (cache.recurrent) {
+ if (seq_id >= (int64_t) cache.size) {
+ // could be fatal
+ return false;
+ }
+ if (0 <= seq_id) {
+ // partial intersection is invalid
+ if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
+ return false;
+ }
+ } else {
+ // seq_id is negative, then the range should include everything or nothing
+ if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+ return false;
+ }
+ }
+ }
+
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
if (seq_id < 0) {
@@ -2215,6 +2382,8 @@ static void llama_kv_cache_seq_rm(
// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
+
+ return true;
}
static void llama_kv_cache_seq_cp(
@@ -2226,6 +2395,29 @@ static void llama_kv_cache_seq_cp(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ if (cache.recurrent) {
+ if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
+ seq_id_src = cache.cells[seq_id_src].src;
+ GGML_ASSERT((uint32_t) seq_id_src < cache.size);
+ // intent to "copy from"
+ // supports copy chains thanks to taking the source of the source
+ cache.cells[seq_id_dst].src = seq_id_src;
+
+ // preserve the "keep or clear" status of the copied sequence
+ if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
+ cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+ } else {
+ cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
+ }
+
+ cache.do_copy = true;
+
+ cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
+ }
+ return;
+ }
+ // otherwise, this is the KV cache of a Transformer-like model
+
cache.head = 0;
for (uint32_t i = 0; i < cache.size; ++i) {
@@ -2265,6 +2457,17 @@ static void llama_kv_cache_seq_add(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ if (cache.recurrent) {
+ // for Mamba-like models, only the pos needs to be shifted
+ if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+ llama_kv_cell & cell = cache.cells[seq_id];
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+ cell.pos += delta;
+ }
+ }
+ return;
+ }
+
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.has_shift = true;
@@ -2298,6 +2501,17 @@ static void llama_kv_cache_seq_div(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ if (cache.recurrent) {
+ // for Mamba-like models, only the pos needs to be changed
+ if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+ llama_kv_cell & cell = cache.cells[seq_id];
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+ cell.pos /= d;
+ }
+ }
+ return;
+ }
+
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.has_shift = true;
@@ -3117,7 +3331,7 @@ static void llm_load_hparams(
// sanity check for n_rot (optional)
{
- hparams.n_rot = hparams.n_embd / hparams.n_head;
+ hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
@@ -3130,10 +3344,10 @@ static void llm_load_hparams(
// gpt-j n_rot = rotary_dim
}
- hparams.n_embd_head_k = hparams.n_embd / hparams.n_head;
+ hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
- hparams.n_embd_head_v = hparams.n_embd / hparams.n_head;
+ hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
// arch-specific KVs
@@ -3383,6 +3597,36 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_MAMBA:
+ {
+ ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
+ ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
+ ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
+ ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 24:
+ switch (hparams.n_embd) {
+ case 768: model.type = e_model::MODEL_SMALL; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ case 48:
+ switch (hparams.n_embd) {
+ case 1024: model.type = e_model::MODEL_MEDIUM; break;
+ case 1536: model.type = e_model::MODEL_LARGE; break;
+ case 2048: model.type = e_model::MODEL_XL; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ case 64:
+ switch (hparams.n_embd) {
+ case 2560: model.type = e_model::MODEL_3B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
@@ -3702,6 +3946,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
+ LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
+ LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
+ LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
+ LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
if (ml.n_elements >= 1e12) {
@@ -4609,6 +4857,57 @@ static bool llm_load_tensors(
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff});
}
} break;
+ case LLM_ARCH_MAMBA:
+ {
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t dt_rank = hparams.ssm_dt_rank;
+ // only an expansion factor of 2 is supported for now
+ GGML_ASSERT(2 * n_embd == d_inner);
+
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ {
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+ // if output is NULL, init from the input tok embed, duplicated to allow offloading
+ if (model.output == NULL) {
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+ ml.n_created--; // artificial tensor
+ ml.size_data += ggml_nbytes(model.output);
+ }
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ // norm
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+ layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
+
+ layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
+ layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
+
+ layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
+
+ layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
+ layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
+
+ // no "weight" suffix for these
+ layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
+ layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});
+
+ // out_proj
+ layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
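As a rough sanity check of the tensor shapes created for LLM_ARCH_MAMBA above, here is an illustrative per-layer parameter count for a hypothetical configuration (n_embd = 768, d_conv = 4, d_state = 16, dt_rank = 48); the numbers are examples, not values taken from this patch:

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_embd = 768, d_conv = 4, d_state = 16, dt_rank = 48;
    const int64_t d_inner = 2 * n_embd; // 1536, per the expansion-factor assert

    const int64_t ssm_in     = n_embd * 2 * d_inner;              // {n_embd, 2*d_inner}
    const int64_t ssm_conv1d = d_conv * d_inner + d_inner;        // weight {d_conv, d_inner} + bias {d_inner}
    const int64_t ssm_x      = d_inner * (dt_rank + 2 * d_state); // {d_inner, dt_rank + 2*d_state}
    const int64_t ssm_dt     = dt_rank * d_inner + d_inner;       // weight {dt_rank, d_inner} + bias {d_inner}
    const int64_t ssm_a_d    = d_state * d_inner + d_inner;       // A {d_state, d_inner} and D {d_inner}
    const int64_t ssm_out    = d_inner * n_embd;                  // {d_inner, n_embd}
    const int64_t attn_norm  = n_embd;

    std::printf("params per layer ~ %lld\n",
        (long long) (ssm_in + ssm_conv1d + ssm_x + ssm_dt + ssm_a_d + ssm_out + attn_norm));
    return 0;
}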
@@ -4834,6 +5133,8 @@ static void llm_build_kv_store(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ GGML_ASSERT(kv.size == n_ctx);
+
// compute the transposed [n_tokens, n_embd] V matrix
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
@@ -5043,6 +5344,8 @@ static struct ggml_tensor * llm_build_kqv(
cb(kq, "kq_soft_max_ext", il);
}
+ GGML_ASSERT(kv.size == n_ctx);
+
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
@@ -5190,8 +5493,8 @@ struct llm_build_context {
norm_eps (hparams.f_norm_eps),
norm_rms_eps (hparams.f_norm_rms_eps),
n_tokens (batch.n_tokens),
- n_kv (worst_case ? n_ctx : kv_self.n),
- kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
+ n_kv (worst_case ? kv_self.size : kv_self.n),
+ kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
@@ -5220,6 +5523,8 @@ struct llm_build_context {
struct ggml_cgraph * build_k_shift() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+ GGML_ASSERT(kv_self.size == n_ctx);
+
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions
@@ -5238,6 +5543,27 @@ struct llm_build_context {
return gf;
}
+ struct ggml_cgraph * build_s_copy() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ GGML_ASSERT(kv_self.recurrent);
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+ struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+
+ conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
+ ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
+
+ // TODO: name the intermediate tensors with cb()
+
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
+ }
+
+ return gf;
+ }
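Conceptually, build_s_copy gathers whole state rows by their source index and writes the result back in place. A scalar sketch of that semantics on plain arrays (names invented for illustration, not the actual ggml graph):

#include <algorithm>
#include <cstdint>
#include <vector>

// Row i of the state buffer is replaced by row s_copy[i]; rows whose source is
// their own index are effectively left untouched.
static void s_copy_rows(std::vector<float> & states, const std::vector<int32_t> & s_copy, size_t row_size) {
    const std::vector<float> src = states; // snapshot: the gather reads before the write-back
    for (size_t i = 0; i < s_copy.size(); ++i) {
        const size_t j = (size_t) s_copy[i];
        std::copy(src.begin() + j*row_size, src.begin() + (j + 1)*row_size, states.begin() + i*row_size);
    }
}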
+
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@@ -7835,6 +8161,145 @@ struct llm_build_context {
return gf;
}
+
+ struct ggml_cgraph * build_mamba() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ const int64_t d_model = n_embd;
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ GGML_ASSERT(2 * d_model == d_inner);
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t dt_rank = hparams.ssm_dt_rank;
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ // {n_embd, n_tokens}
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+ cb(inpL, "inp_embd", -1);
+
+ struct ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
+ struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
+
+ for (int il = 0; il < n_layer; ++il) {
+ // (ab)using the KV cache to store the states
+ struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+ struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+
+ // clear the states of sequences that start at the beginning of this batch
+ {
+ conv_states = ggml_mul(ctx0,
+ ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
+ state_mask);
+ ssm_states = ggml_mul(ctx0,
+ ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
+ state_mask);
+ }
+
+ conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
+ ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
+ struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
+ // split the above in two
+ // => {d_inner, n_tokens}
+ struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
+ struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
+
+ // conv
+ {
+ // Custom operator which is needed only to ease simultaneous sequence processing.
+ // For a single sequence, the equivalent is to concatenate the columns of conv_states and x,
+ // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
+ // then element-wise multiply that with the conv1d weight,
+ // then sum the elements of each row
+ // (the last two steps amount to a dot product over rows, also doable with mul_mat),
+ // then permute away the ne[0] dimension,
+ // and then you're left with the resulting x tensor.
+ // The new conv_states is the last (d_conv - 1) columns
+ // of the last 3rd dimensional "layer" of the self-overlapping view.
+ // For simultaneous sequences, it's more complicated.
+ struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
+
+ // store the last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
+ ggml_build_forward_expand(gf,
+ ggml_cpy(ctx0,
+ ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
+ ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
+
+ // extract x from x_conv
+ x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
+
+ // bias
+ x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
+
+ x = ggml_silu(ctx0, x);
+ }
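For intuition, a scalar reference of the single-sequence case described in the comment above: a depthwise causal convolution over a rolling window of the last d_conv inputs per channel. This is a sketch of the intended semantics with an invented layout, not the actual ggml_ssm_conv implementation:

#include <cstdint>
#include <vector>

static void ssm_conv_ref(
        std::vector<float>       & conv_state,  // [d_inner, d_conv - 1], last inputs per channel
        const std::vector<float> & x_in,        // [n_tokens, d_inner]
        const std::vector<float> & conv_weight, // [d_inner, d_conv]
        std::vector<float>       & x_out,       // [n_tokens, d_inner]
        int64_t n_tokens, int64_t d_inner, int64_t d_conv) {
    for (int64_t t = 0; t < n_tokens; ++t) {
        for (int64_t c = 0; c < d_inner; ++c) {
            // the window is the stored state followed by the new input
            float sum = 0.0f;
            for (int64_t k = 0; k < d_conv; ++k) {
                const float v = k < d_conv - 1 ? conv_state[c*(d_conv - 1) + k] : x_in[t*d_inner + c];
                sum += v * conv_weight[c*d_conv + k];
            }
            x_out[t*d_inner + c] = sum;
            // slide the window: drop the oldest column, append the new input
            if (d_conv > 1) {
                for (int64_t k = 0; k + 1 < d_conv - 1; ++k) {
                    conv_state[c*(d_conv - 1) + k] = conv_state[c*(d_conv - 1) + k + 1];
                }
                conv_state[c*(d_conv - 1) + (d_conv - 2)] = x_in[t*d_inner + c];
            }
        }
    }
}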
+
+ // ssm
+ {
+ // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
+ struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
+ // split
+ struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
+ struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+ struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+
+ // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
+ dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
+ dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
+
+ // Custom operator to optimize the parallel associative scan
+ // as described in Annex D of the Mamba paper.
+ // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
+ // because only a single tensor can be returned.
+ struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq);
+
+ // store the last states (the second part of y_ssm_states)
+ ggml_build_forward_expand(gf,
+ ggml_cpy(ctx0,
+ ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
+ ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states))));
+
+ struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
+
+ // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
+ y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+ y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
+
+ // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
+ cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
+ }
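Similarly, a scalar sketch of the recurrence the fused scan computes for one token of one sequence, per the discretization in the Mamba paper; dt is assumed to already be positive here (e.g. after a softplus), and the D*x skip term and SiLU(z) gate are applied outside the op, as in the graph above. Names and layout are invented for illustration:

#include <cmath>
#include <cstdint>
#include <vector>

static void ssm_scan_step(
        std::vector<float>       & h,  // state, [d_inner, d_state]
        const std::vector<float> & x,  // input for this token, [d_inner]
        const std::vector<float> & dt, // per-channel step size, [d_inner]
        const std::vector<float> & A,  // [d_inner, d_state]
        const std::vector<float> & B,  // input projection for this token, [d_state]
        const std::vector<float> & C,  // output projection for this token, [d_state]
        std::vector<float>       & y,  // output, [d_inner]
        int64_t d_inner, int64_t d_state) {
    for (int64_t c = 0; c < d_inner; ++c) {
        float acc = 0.0f;
        for (int64_t s = 0; s < d_state; ++s) {
            float & hs = h[c*d_state + s];
            // h <- exp(dt*A) * h + dt*B*x, then y = C . h
            hs = std::exp(dt[c] * A[c*d_state + s]) * hs + dt[c] * B[s] * x[c];
            acc += C[s] * hs;
        }
        y[c] = acc;
    }
}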
+
+ // residual
+ cur = ggml_add(ctx0, cur, inpL);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ // final rmsnorm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -7871,6 +8336,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
return result;
}
+static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
+ llama_batch dummy;
+ dummy.n_tokens = 0;
+
+ llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+ struct llm_build_context llm(lctx, dummy, cb, false);
+
+ llm.init();
+
+ struct ggml_cgraph * result = llm.build_s_copy();
+
+ llm.free();
+
+ return result;
+}
+
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
const llama_batch & batch,
@@ -7985,6 +8467,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_starcoder2();
} break;
+ case LLM_ARCH_MAMBA:
+ {
+ result = llm.build_mamba();
+ } break;
default:
GGML_ASSERT(false);
}
@@ -7995,19 +8481,29 @@ static struct ggml_cgraph * llama_build_graph(
}
static void llama_set_k_shift(llama_context & lctx) {
- const auto & cparams = lctx.cparams;
-
- const int64_t n_ctx = cparams.n_ctx;
+ const int64_t kv_size = lctx.kv_self.size;
assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
int32_t * data = (int32_t *) lctx.inp_K_shift->data;
- for (int i = 0; i < n_ctx; ++i) {
+ for (int i = 0; i < kv_size; ++i) {
data[i] = lctx.kv_self.cells[i].delta;
}
}
+static void llama_set_s_copy(llama_context & lctx) {
+ const int64_t kv_size = lctx.kv_self.size;
+
+ assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+ int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+ for (int i = 0; i < kv_size; ++i) {
+ data[i] = lctx.kv_self.cells[i].src;
+ }
+}
+
static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
//
// set input data
@@ -8044,6 +8540,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
float * data = (float *) lctx.inp_KQ_mask->data;
+ // For causal attention, use only the previous KV cells
+ // of the correct sequence for each token of the batch.
+ // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
@@ -8149,6 +8648,53 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
}
+
+ if (kv_self.recurrent) {
+ const int64_t n_kv = kv_self.n;
+
+ {
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
+ float * data = (float *) lctx.inp_s_mask->data;
+
+ // states which are not affected by the current batch are left untouched
+ for (int i = 0; i < n_kv; ++i) {
+ llama_seq_id seq_id = i + lctx.kv_self.head;
+ llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
+ bool has_self_seq = kv_cell.has_seq_id(seq_id);
+
+ data[i] = (float) has_self_seq;
+
+ // ensure current sequences will be kept
+ if (!has_self_seq && kv_cell.pos >= 0) {
+ kv_cell.seq_id.insert(seq_id);
+ }
+ }
+ }
+ // For Mamba (and other recurrent architectures),
+ // update the correct state(s)/sequence(s) for each token of the batch.
+ // As with the KQ_mask, if a token in the batch belongs to multiple sequences,
+ // those sequences are assumed to be equivalent (relied upon by ggml_ssm_scan and ggml_ssm_conv, not checked here).
+ {
+ const int64_t n_tokens = batch.n_tokens;
+
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
+ int32_t * data = (int32_t *) lctx.inp_s_seq->data;
+
+ for (int j = 0; j < n_tokens; ++j) {
+ const int32_t n_seq = batch.n_seq_id[j];
+ GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence
+
+ for (int i = 0; i < n_kv; ++i) {
+ if (i < n_seq) {
+ // for recurrent models, kv_self.head is the smallest seq_id in the batch, so this is an index relative to the state window
+ data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head;
+ } else {
+ data[j*n_kv + i] = -1;
+ }
+ }
+ }
+ }
+ }
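A worked example of the inp_s_seq layout (illustrative numbers): a batch with one token of sequence 1 followed by one token of sequence 2 gives kv_head = 1 and n_kv = 2, so the relative indices written are {0, -1} and {1, -1}:

#include <cstdint>
#include <cstdio>

int main() {
    const int32_t n_kv = 2, n_tokens = 2, kv_head = 1; // head = smallest seq_id in the batch
    const int32_t seq_id[2] = {1, 2};                  // one sequence per token in this example
    int32_t data[4];
    for (int32_t j = 0; j < n_tokens; ++j) {
        for (int32_t i = 0; i < n_kv; ++i) {
            data[j*n_kv + i] = i < 1 ? seq_id[j] - kv_head : -1;
        }
    }
    std::printf("%d %d %d %d\n", data[0], data[1], data[2], data[3]); // 0 -1 1 -1
    return 0;
}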
}
static void llama_graph_compute(
@@ -8271,11 +8817,13 @@ static int llama_decode_internal(
return 1;
}
- // a heuristic, to avoid attending the full cache if it is not yet utilized
- // after enough generations, the benefit from this heuristic disappears
- // if we start defragmenting the cache, the benefit from this will be more important
- kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
- //kv_self.n = llama_kv_cache_cell_max(kv_self);
+ if (!kv_self.recurrent) {
+ // a heuristic, to avoid attending the full cache if it is not yet utilized
+ // after enough generations, the benefit from this heuristic disappears
+ // if we start defragmenting the cache, the benefit from this will be more important
+ kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+ //kv_self.n = llama_kv_cache_cell_max(kv_self);
+ }
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
@@ -8701,6 +9249,26 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
}
}
+ if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
+ llama_set_s_copy(lctx);
+
+ {
+ ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
+
+ llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+ }
+
+ {
+ auto & kv_self = lctx.kv_self;
+
+ kv_self.do_copy = false;
+
+ for (uint32_t i = 0; i < kv_self.size; ++i) {
+ kv_self.cells[i].src = i;
+ }
+ }
+ }
+
// defragment the KV cache if needed
if (lctx.kv_self.do_defrag) {
llama_kv_cache_defrag_internal(lctx);
@@ -11535,6 +12103,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
+ // do not quantize Mamba's small yet 2-dimensional weights
+ // NOTE: can't use LLM_TN here because the layer number is not known
+ quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
+ quantize &= name.find("ssm_x.weight") == std::string::npos;
+ quantize &= name.find("ssm_dt.weight") == std::string::npos;
+
enum ggml_type new_type;
void * new_data;
size_t new_size;
@@ -11985,6 +12559,7 @@ struct llama_context_params llama_context_default_params() {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512,
/*.n_batch =*/ 512,
+ /*.n_parallel =*/ 1,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
@@ -12146,6 +12721,7 @@ struct llama_context * llama_new_context_with_model(
auto & cparams = ctx->cparams;
cparams.n_batch = params.n_batch;
+ // TODO: maybe add n_parallel here too
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -12203,8 +12779,18 @@ struct llama_context * llama_new_context_with_model(
ctx->rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all;
- const ggml_type type_k = params.type_k;
- const ggml_type type_v = params.type_v;
+ uint32_t kv_size = cparams.n_ctx;
+ ggml_type type_k = params.type_k;
+ ggml_type type_v = params.type_v;
+
+ // Mamba only needs a constant number of KV cache cells per sequence
+ if (model->arch == LLM_ARCH_MAMBA) {
+ // Mamba needs at least as many KV cells as there are sequences kept at any time
+ kv_size = std::max((uint32_t) 1, params.n_parallel);
+ // it's probably best to keep as much precision as possible for the states
+ type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+ type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+ }
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
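From the caller's side this means that, for Mamba, the number of usable sequences is fixed by n_parallel at context creation. A minimal usage sketch (the model path and the value 4 are placeholders):

#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("mamba-model.gguf", mparams);
    if (model == NULL) return 1;

    llama_context_params cparams = llama_context_default_params();
    cparams.n_parallel = 4; // for Mamba: one recurrent state slot per sequence

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) return 1;

    // ... decode batches using seq_id 0..3 ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}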
@@ -12304,7 +12890,7 @@ struct llama_context * llama_new_context_with_model(
}
ctx->backends.push_back(ctx->backend_cpu);
- if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) {
+ if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
@@ -12338,7 +12924,7 @@ struct llama_context * llama_new_context_with_model(
// graph inputs
{
ggml_init_params init_params = {
- /* .mem_size */ ggml_tensor_overhead()*8,
+ /* .mem_size */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.recurrent)),
/* .mem_buffer */ nullptr,
/* .no_alloc */ true,
};
@@ -12347,11 +12933,16 @@ struct llama_context * llama_new_context_with_model(
ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
- ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
- ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
- ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
+ ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, kv_size, cparams.n_batch);
+ ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
+ ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
+ if (ctx->kv_self.recurrent) {
+ ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
+ ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
+ ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
+ }
ggml_set_name(ctx->inp_tokens, "inp_tokens");
ggml_set_name(ctx->inp_embd, "inp_embd");
@@ -12361,6 +12952,11 @@ struct llama_context * llama_new_context_with_model(
ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
ggml_set_name(ctx->inp_mean, "inp_mean");
ggml_set_name(ctx->inp_cls, "inp_cls");
+ if (ctx->kv_self.recurrent) {
+ ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
+ ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
+ ggml_set_name(ctx->inp_s_seq, "inp_s_seq");
+ }
ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
@@ -12447,6 +13043,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
return ctx->cparams.n_batch;
}
+uint32_t llama_n_max_seq(const struct llama_context * ctx) {
+ return ctx->kv_self.size;
+}
+
enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
return model->vocab.type;
}
@@ -12460,6 +13060,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_MPT:
case LLM_ARCH_REFACT:
case LLM_ARCH_BLOOM:
+ case LLM_ARCH_MAMBA:
return LLAMA_ROPE_TYPE_NONE;
// use what we call a normal RoPE, operating on pairs of consecutive head values
@@ -12713,8 +13314,8 @@ void llama_kv_cache_clear(struct llama_context * ctx) {
llama_kv_cache_clear(ctx->kv_self);
}
-void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
- llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
+bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+ return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
}
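Since the wrapper now forwards a bool, callers can detect removals the cache cannot honour (e.g. a recurrent cache that cannot drop only part of a sequence's history) and react. One possible policy, sketched under that assumption:

#include "llama.h"

// If a partial removal is rejected, fall back to dropping the whole sequence.
// Illustrative policy only; the exact failure conditions depend on the cache type.
static void remove_or_reset(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    if (!llama_kv_cache_seq_rm(ctx, seq_id, p0, p1)) {
        llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
    }
}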
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
@@ -12891,8 +13492,8 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
const size_t kv_buf_size = kv_self.total_size();
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
@@ -12913,6 +13514,17 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
data_ctx->write(tmp_buf.data(), tmp_buf.size());
+ if (kv_self.recurrent) {
+ // v is contiguous for recurrent models
+ // TODO: use tensors other than k and v for recurrent state models
+ const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
+
+ tmp_buf.resize(v_size);
+ ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
+ data_ctx->write(tmp_buf.data(), tmp_buf.size());
+ continue;
+ }
+
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
@@ -13005,8 +13617,8 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
size_t kv_buf_size;
uint32_t kv_head;
@@ -13027,6 +13639,16 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
inp += k_size;
+ if (kv_self.recurrent) {
+ // v is contiguous for recurrent models
+ // TODO: use tensors other than k and v for recurrent state models
+ const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
+
+ ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size);
+ inp += v_size;
+ continue;
+ }
+
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
diff --git a/llama.h b/llama.h
index 3dc162b0..7a107c7f 100644
--- a/llama.h
+++ b/llama.h
@@ -235,6 +235,7 @@ extern "C" {
uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context, 0 = from model
uint32_t n_batch; // prompt processing maximum batch size
+ uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models)
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
@@ -376,6 +377,7 @@ extern "C" {
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
+ LLAMA_API uint32_t llama_n_max_seq (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
@@ -502,7 +504,7 @@ extern "C" {
// seq_id < 0 : match any sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
- LLAMA_API void llama_kv_cache_seq_rm(
+ LLAMA_API bool llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,