Diffstat (limited to 'llama.cpp')
-rw-r--r--  llama.cpp  25
1 file changed, 22 insertions, 3 deletions
diff --git a/llama.cpp b/llama.cpp
index b19616e8..24944216 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1744,6 +1744,7 @@ struct llama_cparams {
     float defrag_thold;
     bool embeddings;
+    bool causal_attn;
     bool offload_kqv;
     enum llama_pooling_type pooling_type;
@@ -3939,6 +3940,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_ff          = %u\n", __func__, hparams.n_ff);
     LLAMA_LOG_INFO("%s: n_expert      = %u\n", __func__, hparams.n_expert);
     LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
+    LLAMA_LOG_INFO("%s: causal attn   = %d\n", __func__, hparams.causal_attn);
     LLAMA_LOG_INFO("%s: pooling type  = %d\n", __func__, hparams.pooling_type);
     LLAMA_LOG_INFO("%s: rope type     = %d\n", __func__, hparams.rope_type);
     LLAMA_LOG_INFO("%s: rope scaling  = %s\n", __func__, rope_scaling_type);
@@ -8532,7 +8534,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
-    if (hparams.causal_attn) {
+    GGML_ASSERT(
+        (hparams.causal_attn || !cparams.causal_attn) &&
+        "non-causal attention with generative models is not supported"
+    );
+
+    // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
+    if (cparams.causal_attn) {
         const int64_t n_kv     = kv_self.n;
         const int64_t n_tokens = batch.n_tokens;
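
The assertion introduced here ties the two flags together: hparams.causal_attn records whether the model itself is causal/generative, while cparams.causal_attn is the per-context request, and the only combination the predicate rejects is a causal request on a model that is not causal-capable. A standalone sketch (illustration only, not part of the patch) that enumerates the predicate:

// Illustration only: evaluate the GGML_ASSERT predicate above for all four
// combinations of the model flag (hparams.causal_attn) and the context flag
// (cparams.causal_attn).
#include <cstdio>

int main() {
    for (int model_causal = 0; model_causal <= 1; ++model_causal) {
        for (int ctx_causal = 0; ctx_causal <= 1; ++ctx_causal) {
            // same predicate as: hparams.causal_attn || !cparams.causal_attn
            const bool ok = model_causal || !ctx_causal;
            std::printf("hparams.causal_attn=%d cparams.causal_attn=%d -> %s\n",
                        model_causal, ctx_causal, ok ? "allowed" : "assert fails");
        }
    }
    return 0;
}
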
@@ -8560,8 +8568,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
     } else {
-        // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
+        // when using kv cache, the mask needs to match the kv cache size
         const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
         assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
@@ -8580,7 +8589,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         }
                     }
-                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
+                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
+                }
+
+                for (int i = n_tokens; i < n_stride; ++i) {
+                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
                 }
             }
         }
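
Because a causal-capable model can now take this non-causal branch, the mask row stride is kv_self.n instead of n_tokens, and the columns past the batch are padded with -INFINITY so the unused KV slots are never attended to. A minimal standalone sketch of that layout, with made-up sizes and without the per-sequence check done in the real loop:

// Toy example only: n_tokens and n_stride are assumed values, not read from a
// llama_context.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 3; // tokens in the current batch
    const int n_stride = 5; // stands in for kv_self.n when the KV cache is in use
    std::vector<float> data(n_tokens * n_stride, 0.0f);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_tokens; ++i) {
            data[j*n_stride + i] = 0.0f;      // batch tokens may attend to each other
        }
        for (int i = n_tokens; i < n_stride; ++i) {
            data[j*n_stride + i] = -INFINITY; // pad the remaining KV columns
        }
    }

    // print the 3x5 mask; the last two columns of every row come out as -inf
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_stride; ++i) {
            std::printf("%8.1f", data[j*n_stride + i]);
        }
        std::printf("\n");
    }
    return 0;
}
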
@@ -12733,6 +12746,8 @@ struct llama_context * llama_new_context_with_model(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
+    cparams.causal_attn = hparams.causal_attn;
+
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
             cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
@@ -13767,6 +13782,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }
+void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
+    ctx->cparams.causal_attn = causal_attn;
+}
+
 struct llama_batch llama_batch_get_one(
         llama_token * tokens,
         int32_t n_tokens,
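
For API consumers, the new llama_set_causal_attn call lets a single context switch between non-causal and causal attention at runtime, e.g. to run an embedding pass with a generative model and then resume normal decoding. A hedged usage sketch, assuming the matching declaration is exported from llama.h and that the context and batches were created elsewhere; return codes are ignored for brevity:

// Usage sketch only: toggle the new flag around two llama_decode() calls.
#include "llama.h"

void run_non_causal_then_causal(struct llama_context * ctx,
                                struct llama_batch embd_batch,
                                struct llama_batch gen_batch) {
    // non-causal pass: tokens in the batch attend bidirectionally
    llama_set_causal_attn(ctx, false);
    llama_decode(ctx, embd_batch);

    // back to regular causal decoding
    llama_set_causal_attn(ctx, true);
    llama_decode(ctx, gen_batch);
}

Note that llama_new_context_with_model now seeds cparams.causal_attn from hparams.causal_attn, so the setter only needs to be called when deviating from the model's default.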