summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp95
1 files changed, 62 insertions, 33 deletions
diff --git a/llama.cpp b/llama.cpp
index 04a779e0..ed876668 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1450,7 +1450,10 @@ static bool llama_kv_cache_find_slot(
for (uint32_t i = 0; i < n_tokens; i++) {
cache.cells[cache.head + i].pos = batch.pos[i];
- cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]);
+
+ for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+ cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+ }
}
return true;
@@ -1530,6 +1533,9 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
+ } else {
+ cache.cells[i].seq_id.clear();
+ cache.cells[i].seq_id.insert(seq_id);
}
}
@@ -3178,7 +3184,7 @@ static struct ggml_cgraph * llm_build_llama(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -3564,7 +3570,7 @@ static struct ggml_cgraph * llm_build_baichaun(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -3963,7 +3969,7 @@ static struct ggml_cgraph * llm_build_refact(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4315,7 +4321,7 @@ static struct ggml_cgraph * llm_build_falcon(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4667,7 +4673,7 @@ static struct ggml_cgraph * llm_build_starcoder(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4898,7 +4904,7 @@ static struct ggml_cgraph * llm_build_persimmon(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
@@ -5296,7 +5302,7 @@ static struct ggml_cgraph * llm_build_bloom(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -5564,7 +5570,7 @@ static struct ggml_cgraph * llm_build_mpt(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -5864,8 +5870,11 @@ static int llama_decode_internal(
// helpers for smoother batch API transistion
// after deprecating the llama_eval calls, these will be removed
- std::vector<llama_pos> pos;
- std::vector<llama_seq_id> seq_id;
+ std::vector<llama_pos> pos;
+
+ std::vector<int32_t> n_seq_id;
+ std::vector<llama_seq_id *> seq_id_arr;
+ std::vector<std::vector<llama_seq_id>> seq_id;
if (batch.pos == nullptr) {
pos.resize(n_tokens);
@@ -5877,12 +5886,18 @@ static int llama_decode_internal(
}
if (batch.seq_id == nullptr) {
+ n_seq_id.resize(n_tokens);
seq_id.resize(n_tokens);
+ seq_id_arr.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
- seq_id[i] = batch.all_seq_id;
+ n_seq_id[i] = 1;
+ seq_id[i].resize(1);
+ seq_id[i][0] = batch.all_seq_id;
+ seq_id_arr[i] = seq_id[i].data();
}
- batch.seq_id = seq_id.data();
+ batch.n_seq_id = n_seq_id.data();
+ batch.seq_id = seq_id_arr.data();
}
if (!llama_kv_cache_find_slot(kv_self, batch)) {
@@ -9109,6 +9124,9 @@ void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llam
}
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+ if (seq_id_src == seq_id_dst) {
+ return;
+ }
llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
}
@@ -9561,7 +9579,7 @@ int llama_eval_embd(
int n_past) {
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
- llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
+ llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
const int ret = llama_decode_internal(*ctx, batch);
if (ret < 0) {
@@ -9582,20 +9600,21 @@ struct llama_batch llama_batch_get_one(
llama_pos pos_0,
llama_seq_id seq_id) {
return {
- /*n_tokens =*/ n_tokens,
- /*tokens =*/ tokens,
- /*embd =*/ nullptr,
- /*pos =*/ nullptr,
- /*seq_id =*/ nullptr,
- /*logits =*/ nullptr,
- /*all_pos_0 =*/ pos_0,
- /*all_pos_1 =*/ 1,
- /*all_seq_id =*/ seq_id,
+ /*n_tokens =*/ n_tokens,
+ /*tokens =*/ tokens,
+ /*embd =*/ nullptr,
+ /*pos =*/ nullptr,
+ /*n_seq_id =*/ nullptr,
+ /*seq_id =*/ nullptr,
+ /*logits =*/ nullptr,
+ /*all_pos_0 =*/ pos_0,
+ /*all_pos_1 =*/ 1,
+ /*all_seq_id =*/ seq_id,
};
}
-struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
- llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
+struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
+ llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
@@ -9603,19 +9622,29 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
}
- batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
- batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens);
- batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+ batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
+ batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
+ for (int i = 0; i < n_tokens; ++i) {
+ batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+ }
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return batch;
}
void llama_batch_free(struct llama_batch batch) {
- if (batch.token) free(batch.token);
- if (batch.embd) free(batch.embd);
- if (batch.pos) free(batch.pos);
- if (batch.seq_id) free(batch.seq_id);
- if (batch.logits) free(batch.logits);
+ if (batch.token) free(batch.token);
+ if (batch.embd) free(batch.embd);
+ if (batch.pos) free(batch.pos);
+ if (batch.n_seq_id) free(batch.n_seq_id);
+ if (batch.seq_id) {
+ for (int i = 0; i < batch.n_tokens; ++i) {
+ free(batch.seq_id[i]);
+ }
+ free(batch.seq_id);
+ }
+ if (batch.logits) free(batch.logits);
}
int llama_decode(