diff options
Diffstat (limited to 'examples/baby-llama/baby-llama.cpp')
-rw-r--r-- | examples/baby-llama/baby-llama.cpp | 37 |
1 file changed, 31 insertions, 6 deletions
diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index ed61125e..b02a8086 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -554,6 +554,14 @@ static struct ggml_tensor * forward( struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * vc = kv_self.v; + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // inpL shape [n_embd,N,1,1] struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); for (int il = 0; il < n_layer; ++il) { @@ -581,8 +589,8 @@ static struct ggml_tensor * forward( // wk shape [n_embd, n_embd, 1, 1] // Qcur shape [n_embd/n_head, n_head, N, 1] // Kcur shape [n_embd/n_head, n_head, N, 1] - struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); + struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0); + struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0); // store key and value to memory { @@ -808,9 +816,18 @@ static struct ggml_tensor * forward_batch( struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * vc = kv_self.v; + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // inpL shape [n_embd,N*n_batch,1] struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); assert_shape_2d(inpL, n_embd, 
N*n_batch); + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -838,8 +855,8 @@ static struct ggml_tensor * forward_batch( // wk shape [n_embd, n_embd, 1, 1] // Qcur shape [n_embd/n_head, n_head, N, n_batch] // Kcur shape [n_embd/n_head, n_head, N, n_batch] - struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); + struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0); + struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0); assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); @@ -1097,6 +1114,14 @@ static struct ggml_tensor * forward_lora( struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * vc = kv_self.v; + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // inpL shape [n_embd,N,1,1] struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); for (int il = 0; il < n_layer; ++il) { @@ -1130,7 +1155,7 @@ static struct ggml_tensor * forward_lora( model->layers[il].wqb, cur)), n_embd/n_head, n_head, N), - n_past, n_rot, 0, 0); + KQ_pos, n_rot, 0, 0); struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, @@ -1139,7 +1164,7 @@ static struct ggml_tensor * forward_lora( model->layers[il].wkb, cur)), n_embd/n_head, n_head, N), - n_past, n_rot, 0, 0); + KQ_pos, n_rot, 0, 0); // store key 
and value to memory { |