diff options
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 25 |
1 files changed, 22 insertions, 3 deletions
@@ -1744,6 +1744,7 @@ struct llama_cparams { float defrag_thold; bool embeddings; + bool causal_attn; bool offload_kqv; enum llama_pooling_type pooling_type; @@ -3939,6 +3940,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: causal attm = %d\n", __func__, hparams.causal_attn); LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type); @@ -8532,7 +8534,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } - if (hparams.causal_attn) { + GGML_ASSERT( + (hparams.causal_attn || !cparams.causal_attn) && + "non-causal attention with generative models is not supported" + ); + + // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. + if (cparams.causal_attn) { const int64_t n_kv = kv_self.n; const int64_t n_tokens = batch.n_tokens; @@ -8560,8 +8568,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } else { - // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used) + // when using kv cache, the mask needs to match the kv cache size const int64_t n_tokens = batch.n_tokens; + const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens; assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -8580,7 +8589,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f; + data[h*(n_tokens*n_tokens) + j*n_stride + i] = f; + } + + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY; } } } @@ -12733,6 +12746,8 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; } + cparams.causal_attn = hparams.causal_attn; + if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; @@ -13767,6 +13782,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback) ctx->abort_callback_data = abort_callback_data; } +void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { + ctx->cparams.causal_attn = causal_attn; +} + struct llama_batch llama_batch_get_one( llama_token * tokens, int32_t n_tokens, |