diff options
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 95 |
1 files changed, 62 insertions, 33 deletions
@@ -1450,7 +1450,10 @@ static bool llama_kv_cache_find_slot( for (uint32_t i = 0; i < n_tokens; i++) { cache.cells[cache.head + i].pos = batch.pos[i]; - cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]); + + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + } } return true; @@ -1530,6 +1533,9 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); if (new_head == cache.size) new_head = i; + } else { + cache.cells[i].seq_id.clear(); + cache.cells[i].seq_id.insert(seq_id); } } @@ -3178,7 +3184,7 @@ static struct ggml_cgraph * llm_build_llama( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -3564,7 +3570,7 @@ static struct ggml_cgraph * llm_build_baichaun( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -3963,7 +3969,7 @@ static struct ggml_cgraph * llm_build_refact( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -4315,7 +4321,7 @@ static struct ggml_cgraph * llm_build_falcon( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -4667,7 +4673,7 @@ static struct ggml_cgraph * llm_build_starcoder( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -4898,7 +4904,7 @@ static struct ggml_cgraph * llm_build_persimmon( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; @@ -5296,7 +5302,7 @@ static struct ggml_cgraph * llm_build_bloom( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -5564,7 +5570,7 @@ static struct ggml_cgraph * llm_build_mpt( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -5864,8 +5870,11 @@ static int llama_decode_internal( // helpers for smoother batch API transistion // after deprecating the llama_eval calls, these will be removed - std::vector<llama_pos> pos; - std::vector<llama_seq_id> seq_id; + std::vector<llama_pos> pos; + + std::vector<int32_t> n_seq_id; + std::vector<llama_seq_id *> seq_id_arr; + std::vector<std::vector<llama_seq_id>> seq_id; if (batch.pos == nullptr) { pos.resize(n_tokens); @@ -5877,12 +5886,18 @@ static int llama_decode_internal( } if (batch.seq_id == nullptr) { + n_seq_id.resize(n_tokens); seq_id.resize(n_tokens); + seq_id_arr.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { - seq_id[i] = batch.all_seq_id; + n_seq_id[i] = 1; + seq_id[i].resize(1); + seq_id[i][0] = batch.all_seq_id; + seq_id_arr[i] = seq_id[i].data(); } - batch.seq_id = seq_id.data(); + batch.n_seq_id = n_seq_id.data(); + batch.seq_id = seq_id_arr.data(); } if (!llama_kv_cache_find_slot(kv_self, batch)) { @@ -9109,6 +9124,9 @@ void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llam } void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); } @@ -9561,7 +9579,7 @@ int llama_eval_embd( int n_past) { llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, }; + llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { @@ -9582,20 +9600,21 @@ struct llama_batch llama_batch_get_one( llama_pos pos_0, llama_seq_id seq_id) { return { - /*n_tokens =*/ n_tokens, - /*tokens =*/ tokens, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, - /*all_pos_0 =*/ pos_0, - /*all_pos_1 =*/ 1, - /*all_seq_id =*/ seq_id, + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*all_pos_0 =*/ pos_0, + /*all_pos_1 =*/ 1, + /*all_seq_id =*/ seq_id, }; } -struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { - llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; +struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) { + llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; if (embd) { batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); @@ -9603,19 +9622,29 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); } - batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); - batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens); - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); return batch; } void llama_batch_free(struct llama_batch batch) { - if (batch.token) free(batch.token); - if (batch.embd) free(batch.embd); - if (batch.pos) free(batch.pos); - if (batch.seq_id) free(batch.seq_id); - if (batch.logits) free(batch.logits); + if (batch.token) free(batch.token); + if (batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; i < batch.n_tokens; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); } int llama_decode( |