summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
authords5t5 <145942675+ds5t5@users.noreply.github.com>2023-10-04 06:23:39 -0700
committerGitHub <noreply@github.com>2023-10-04 16:23:39 +0300
commitf8c90cdbaa729e64493164c1aba7ea80da7b716f (patch)
treed5d15cafc28bd2d982705c92bad32a8ea9b90e4b /llama.cpp
parentf93af02488179b9c52d0d391b08ae4c4d891b8d3 (diff)
llm : add Refact model (#3329)
* add refact model * resolve comments * rebase to the latest * solve alibi cpu error --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp382
1 files changed, 381 insertions, 1 deletions
diff --git a/llama.cpp b/llama.cpp
index a40da683..08d6c162 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -165,6 +165,7 @@ enum llm_arch {
LLM_ARCH_GPTNEOX,
LLM_ARCH_MPT,
LLM_ARCH_STARCODER,
+ LLM_ARCH_REFACT,
LLM_ARCH_UNKNOWN,
};
@@ -177,6 +178,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
{ LLM_ARCH_MPT, "mpt" },
{ LLM_ARCH_BAICHUAN, "baichuan" },
{ LLM_ARCH_STARCODER, "starcoder" },
+ { LLM_ARCH_REFACT, "refact" },
};
enum llm_kv {
@@ -398,6 +400,23 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
},
},
{
+ LLM_ARCH_REFACT,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
+ {
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
@@ -1927,6 +1946,14 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_REFACT:
+ {
+ GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+ switch (hparams.n_layer) {
+ case 32: model.type = e_model::MODEL_1B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
@@ -2164,6 +2191,7 @@ static void llm_load_tensors(
const auto tn = LLM_TN(model.arch);
switch (model.arch) {
case LLM_ARCH_LLAMA:
+ case LLM_ARCH_REFACT:
{
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
@@ -3357,6 +3385,353 @@ static struct ggml_cgraph * llm_build_baichaun(
return gf;
}
+static struct ggml_cgraph * llm_build_refact(
+ llama_context & lctx,
+ const llama_batch & batch) {
+ const auto & model = lctx.model;
+ const auto & hparams = model.hparams;
+ const auto & cparams = lctx.cparams;
+
+ const auto & kv_self = lctx.kv_self;
+
+ GGML_ASSERT(!!kv_self.ctx);
+
+ const int64_t n_embd = hparams.n_embd;
+ const int64_t n_layer = hparams.n_layer;
+ const int64_t n_ctx = cparams.n_ctx;
+ const int64_t n_head = hparams.n_head;
+ const int64_t n_head_kv = hparams.n_head_kv;
+ const int64_t n_embd_head = hparams.n_embd_head();
+ const int64_t n_embd_gqa = hparams.n_embd_gqa();
+
+ const float norm_rms_eps = hparams.f_norm_rms_eps;
+
+ const int n_gpu_layers = model.n_gpu_layers;
+
+ const int32_t n_tokens = batch.n_tokens;
+ const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
+ const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
+
+ // printf("n_kv = %d\n", n_kv);
+
+ auto & buf_compute = lctx.buf_compute;
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ buf_compute.size,
+ /*.mem_buffer =*/ buf_compute.data,
+ /*.no_alloc =*/ false,
+ };
+
+ params.no_alloc = true;
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ if (batch.token) {
+ struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+
+ ggml_allocr_alloc(lctx.alloc, inp_tokens);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
+ }
+ ggml_set_name(inp_tokens, "inp_tokens");
+
+ inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
+ } else {
+#ifdef GGML_USE_MPI
+ GGML_ASSERT(false && "not implemented");
+#endif
+
+ inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
+
+ ggml_allocr_alloc(lctx.alloc, inpL);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
+ }
+ }
+
+ const int i_gpu_start = n_layer - n_gpu_layers;
+ (void) i_gpu_start;
+
+ // offload functions set the tensor output backend to GPU
+ // tensors are GPU-accelerated if any input or the output has been offloaded
+ offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+ offload_func_t offload_func_kq = llama_nop;
+ offload_func_t offload_func_v = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+ if (n_gpu_layers > n_layer) {
+ offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
+ }
+ if (n_gpu_layers > n_layer + 1) {
+ offload_func_v = ggml_cuda_assign_buffers_no_alloc;
+ }
+ if (n_gpu_layers > n_layer + 2) {
+ offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
+ }
+#endif // GGML_USE_CUBLAS
+
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
+ ggml_allocr_alloc(lctx.alloc, KQ_scale);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head)));
+ }
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ offload_func_kq(KQ_mask);
+ ggml_set_name(KQ_mask, "KQ_mask");
+ ggml_allocr_alloc(lctx.alloc, KQ_mask);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ float * data = (float *) KQ_mask->data;
+ memset(data, 0, ggml_nbytes(KQ_mask));
+
+ for (int h = 0; h < 1; ++h) {
+ for (int j = 0; j < n_tokens; ++j) {
+ const llama_pos pos = batch.pos[j];
+ const llama_seq_id seq_id = batch.seq_id[j];
+
+ for (int i = 0; i < n_kv; ++i) {
+ if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+ }
+ }
+ }
+ }
+ }
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_format_name(inpL, "layer_inp_%d", il);
+
+ offload_func_t offload_func = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+ if (il >= i_gpu_start) {
+ offload_func = ggml_cuda_assign_buffers_no_alloc;
+ }
+#endif // GGML_USE_CUBLAS
+
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ {
+ cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
+ offload_func(cur);
+ ggml_set_name(cur, "rms_norm_0");
+
+ // cur = cur*attn_norm(broadcasted)
+ cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
+ offload_func(cur);
+ ggml_set_name(cur, "attention_norm_0");
+ }
+
+ // self-attention
+ {
+ // compute Q and K
+ struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ offload_func_kq(tmpk);
+ ggml_set_name(tmpk, "tmpk");
+
+ struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ offload_func_kq(tmpq);
+ ggml_set_name(tmpq, "tmpq");
+
+ struct ggml_tensor * Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens);
+ offload_func_kq(Kcur);
+ ggml_set_name(Kcur, "Kcur");
+
+ struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
+ offload_func_kq(Qcur);
+ ggml_set_name(Qcur, "Qcur");
+
+ // store key and value to memory
+ {
+ // compute the transposed [n_tokens, n_embd] V matrix
+
+ struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ offload_func_v(tmpv);
+ ggml_set_name(tmpv, "tmpv");
+
+ struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
+ offload_func_v(Vcur);
+ ggml_set_name(Vcur, "Vcur");
+
+ struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+ offload_func_kq(k);
+ ggml_set_name(k, "k");
+
+ struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
+ ( n_ctx)*ggml_element_size(kv_self.v),
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+ offload_func_v(v);
+ ggml_set_name(v, "v");
+
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+ }
+
+ struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+ offload_func_kq(Q);
+ ggml_set_name(Q, "Q");
+
+ struct ggml_tensor * K =
+ ggml_view_3d(ctx0, kv_self.k,
+ n_embd_head, n_kv, n_head_kv,
+ ggml_element_size(kv_self.k)*n_embd_gqa,
+ ggml_element_size(kv_self.k)*n_embd_head,
+ ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+ offload_func_kq(K);
+ ggml_set_name(K, "K");
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+ offload_func_kq(KQ);
+ ggml_set_name(KQ, "KQ");
+
+ // KQ_scaled = KQ / sqrt(n_embd_head)
+ // KQ_scaled shape [n_kv, n_tokens, n_head, 1]
+ struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
+ offload_func_kq(KQ_scaled);
+ ggml_set_name(KQ_scaled, "KQ_scaled");
+
+ // KQ_masked = mask_past(KQ_scaled)
+ struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8);
+ ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
+
+ struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
+ offload_func_kq(KQ_masked);
+ ggml_set_name(KQ_masked, "KQ_masked");
+
+ // KQ = soft_max(KQ_masked)
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+ offload_func_v(KQ_soft_max);
+ ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
+ // split cached V into n_head heads
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_self.v,
+ n_kv, n_embd_head, n_head_kv,
+ ggml_element_size(kv_self.v)*n_ctx,
+ ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+ ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+ offload_func_v(V);
+ ggml_set_name(V, "V");
+
+#if 1
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+ offload_func_v(KQV);
+ ggml_set_name(KQV, "KQV");
+#else
+ // make V contiguous in memory to speed up the matmul, however we waste time on the copy
+ // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
+ // is there a better way?
+ struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head));
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
+#endif
+
+ // KQV_merged = KQV.permute(0, 2, 1, 3)
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ offload_func_v(KQV_merged);
+ ggml_set_name(KQV_merged, "KQV_merged");
+
+ // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
+ offload_func_v(cur);
+ ggml_set_name(cur, "KQV_merged_contiguous");
+
+ // projection (no bias)
+ cur = ggml_mul_mat(ctx0,
+ model.layers[il].wo,
+ cur);
+ offload_func(cur);
+ ggml_set_name(cur, "result_wo");
+ }
+
+ struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+ offload_func(inpFF);
+ ggml_set_name(inpFF, "inpFF");
+
+ // feed-forward network
+ {
+ // norm
+ {
+ cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
+ offload_func(cur);
+ ggml_set_name(cur, "rms_norm_1");
+
+ // cur = cur*ffn_norm(broadcasted)
+ cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
+ offload_func(cur);
+ ggml_set_name(cur, "ffn_norm");
+ }
+
+ struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+ model.layers[il].w3,
+ cur);
+ offload_func(tmp);
+ ggml_set_name(tmp, "result_w3");
+
+ cur = ggml_mul_mat(ctx0,
+ model.layers[il].w1,
+ cur);
+ offload_func(cur);
+ ggml_set_name(cur, "result_w1");
+
+ // SILU activation
+ cur = ggml_silu(ctx0, cur);
+ offload_func(cur);
+ ggml_set_name(cur, "silu");
+
+ cur = ggml_mul(ctx0, cur, tmp);
+ offload_func(cur);
+ ggml_set_name(cur, "silu_x_result_w3");
+
+ cur = ggml_mul_mat(ctx0,
+ model.layers[il].w2,
+ cur);
+ offload_func(cur);
+ ggml_set_name(cur, "result_w2");
+ }
+
+ cur = ggml_add(ctx0, cur, inpFF);
+ offload_func(cur);
+ ggml_set_name(cur, "inpFF_+_result_w2");
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ // norm
+ {
+ cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
+ offload_func_nr(cur);
+ ggml_set_name(cur, "rms_norm_2");
+
+ // cur = cur*norm(broadcasted)
+ cur = ggml_mul(ctx0, cur, model.output_norm);
+ // offload_func_nr(cur); // TODO CPU + GPU mirrored backend
+ ggml_set_name(cur, "result_norm");
+ }
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ ggml_set_name(cur, "result_output");
+
+ ggml_build_forward_expand(gf, cur);
+
+ ggml_free(ctx0);
+
+ return gf;
+}
+
static struct ggml_cgraph * llm_build_falcon(
llama_context & lctx,
const llama_batch & batch) {
@@ -3997,6 +4372,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm_build_starcoder(lctx, batch);
} break;
+ case LLM_ARCH_REFACT:
+ {
+ result = llm_build_refact(lctx, batch);
+ } break;
default:
GGML_ASSERT(false);
}
@@ -4130,7 +4509,8 @@ static int llama_decode_internal(
// If all tensors can be run on the GPU then using more than 1 thread is detrimental.
const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA ||
model.arch == LLM_ARCH_BAICHUAN ||
- model.arch == LLM_ARCH_FALCON;
+ model.arch == LLM_ARCH_FALCON ||
+ model.arch == LLM_ARCH_REFACT;
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {
n_threads = 1;