summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp181
1 files changed, 181 insertions, 0 deletions
diff --git a/llama.cpp b/llama.cpp
index 47b4384a..1cee5a79 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -194,6 +194,7 @@ enum llm_arch {
LLM_ARCH_QWEN,
LLM_ARCH_PHI2,
LLM_ARCH_PLAMO,
+ LLM_ARCH_CODESHELL,
LLM_ARCH_UNKNOWN,
};
@@ -213,6 +214,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN, "qwen" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PLAMO, "plamo" },
+ { LLM_ARCH_CODESHELL, "codeshell" },
};
enum llm_kv {
@@ -600,6 +602,26 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_CODESHELL,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_UNKNOWN,
@@ -2877,6 +2899,14 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_CODESHELL:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+ switch (hparams.n_layer) {
+ case 42: model.type = e_model::MODEL_SMALL; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
@@ -3784,6 +3814,42 @@ static bool llm_load_tensors(
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
}
} break;
+ case LLM_ARCH_CODESHELL:
+ {
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ {
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+ model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+ layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
+
+ layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+ layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa});
+
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+ layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
+
+ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+ layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
+
+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+ layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
+
+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+ layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -5965,6 +6031,117 @@ struct llm_build_context {
return gf;
}
+
+ struct ggml_cgraph * build_codeshell() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+ cb(inpL, "inp_embd", -1);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ cb(inp_pos, "inp_pos", -1);
+
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ cb(KQ_mask, "KQ_mask", -1);
+
+ // shift the entire K-cache if needed
+ if (do_rope_shift) {
+ llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
+ }
+
+ for (int il = 0; il < n_layer; ++il) {
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm,
+ model.layers[il].attn_norm_b,
+ LLM_NORM, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+ cb(cur, "wqkv", il);
+
+ cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+ cb(cur, "bqkv", il);
+
+ struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+ cb(tmpq, "tmpq", il);
+ cb(tmpk, "tmpk", il);
+ cb(Vcur, "Vcur", il);
+
+ struct ggml_tensor * Qcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos,
+ hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
+ hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
+
+ cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+ cb(cur, "kqv_out", il);
+ }
+
+ // add the input
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // FF
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm,
+ model.layers[il].ffn_norm_b,
+ LLM_NORM, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b,
+ NULL, NULL,
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+ NULL,
+ LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ inpL = ggml_add(ctx0, cur, ffn_inp);
+ cb(inpL, "l_out", il);
+ }
+
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.output_norm,
+ model.output_norm_b,
+ LLM_NORM, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
};
static struct ggml_cgraph * llama_build_graph(
@@ -6159,6 +6336,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_gpt2();
} break;
+ case LLM_ARCH_CODESHELL:
+ {
+ result = llm.build_codeshell();
+ } break;
default:
GGML_ASSERT(false);
}