diff options
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 192 |
1 files changed, 191 insertions, 1 deletions
@@ -211,6 +211,7 @@ enum llm_arch { LLM_ARCH_QWEN2, LLM_ARCH_QWEN2MOE, LLM_ARCH_PHI2, + LLM_ARCH_PHI3, LLM_ARCH_PLAMO, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, @@ -246,6 +247,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_QWEN2MOE, "qwen2moe" }, { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, @@ -794,6 +796,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA }, }, { + LLM_ARCH_PHI3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { LLM_ARCH_PLAMO, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, @@ -3961,6 +3980,16 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_PHI3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_PLAMO: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -4352,6 +4381,7 @@ static void llm_load_vocab( //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL && (t.first == "<|eot_id|>" || t.first == "<|im_end|>" || + t.first == "<|end|>" || t.first == "<end_of_turn>" ) ) { @@ -5375,6 +5405,33 @@ static bool llm_load_tensors( layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; + case LLM_ARCH_PHI3: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context* ctx_layer = ctx_for_layer(i); + ggml_context* ctx_split = ctx_for_layer_split(i); + + auto& layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, false); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }); + } + } break; case LLM_ARCH_PLAMO: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -6326,7 +6383,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); - if (model.arch == LLM_ARCH_PHI2) { + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -8967,12 +9024,140 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, model.output_b); cb(cur, "result_output", -1); + ggml_build_forward_expand(gf, cur); + return gf; + } + + struct ggml_cgraph * build_phi3() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + auto residual = inpL; + + // self-attention + { + struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, cb, il); + cb(attn_norm_output, "attn_norm", il); + + struct ggml_tensor * Qcur = nullptr; + struct ggml_tensor * Kcur = nullptr; + struct ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv) { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); + } + else { + Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_custom( + ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor* inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + cur = ggml_add(ctx0, cur, residual); + residual = cur; + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // FF + // special-case: the up and gate tensors are merged into a single tensor + // TOOD: support into llm_build_ffn + { + struct ggml_tensor* up = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); + cb(up, "ffn_up", il); + + auto g = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), 0)); + auto y = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), up->nb[1] / 2)); + + y = ggml_mul(ctx0, y, ggml_silu(ctx0, g)); + cb(y, "ffn_gate", il); + + auto down = ggml_mul_mat(ctx0, model.layers[il].ffn_down, y); + cb(down, "ffn_down", il); + + cur = down; + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, residual, cur); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); return gf; } + struct ggml_cgraph * build_plamo() { struct ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -10474,6 +10659,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_phi2(); } break; + case LLM_ARCH_PHI3: + { + result = llm.build_phi3(); + } break; case LLM_ARCH_PLAMO: { result = llm.build_plamo(); @@ -15393,6 +15582,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: case LLM_ARCH_PHI2: + case LLM_ARCH_PHI3: case LLM_ARCH_GEMMA: case LLM_ARCH_STARCODER2: return LLAMA_ROPE_TYPE_NEOX; |