Diffstat (limited to 'llama.cpp')
 llama.cpp | 1986
 1 file changed, 866 insertions(+), 1120 deletions(-)
diff --git a/llama.cpp b/llama.cpp
index 42cedc7a..d0c4ef10 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3090,6 +3090,10 @@ static bool llama_model_load(
return true;
}
+//
+// llm_build
+//
+
using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * name, int nl)>;
enum llm_rope_type {
@@ -3098,17 +3102,35 @@ enum llm_rope_type {
LLM_ROPE_GLM,
};
+enum llm_ffn_op_type {
+ LLM_FFN_SILU,
+ LLM_FFN_GELU,
+ LLM_FFN_RELU,
+ LLM_FFN_RELU_SQR,
+};
+
+enum llm_ffn_gate_type {
+ LLM_FFN_SEQ,
+ LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
+};
+
+enum llm_norm_type {
+ LLM_NORM,
+ LLM_NORM_RMS,
+};
+
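// NOTE: a minimal sketch of what the two gate types are meant to express; the
//       llm_build_ffn body is outside this hunk, so the exact calls below are
//       an assumption rather than the actual implementation:
//
//           // LLM_FFN_PAR (SwiGLU-style, as used by the LLaMA graph below):
//           // ffn_gate reads the same input as ffn_up and the two branches
//           // are multiplied after the activation
//           tmp = ggml_mul_mat(ctx, up,   x);    // up(x)
//           cur = ggml_mul_mat(ctx, gate, x);    // gate(x), parallel branch
//           cur = ggml_silu   (ctx, cur);        // act(gate(x))
//           cur = ggml_mul    (ctx, cur, tmp);   // act(gate(x)) * up(x)
//           cur = ggml_mul_mat(ctx, down, cur);  // down-projection
//
//           // LLM_FFN_SEQ: the gate (when present) is applied to the result
//           // of ffn_up instead of to the input, i.e. gate(up(x)), before
//           // the activation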
static struct ggml_tensor * llm_build_inp_embd(
struct ggml_context * ctx,
+ const llama_hparams & hparams,
const llama_batch & batch,
struct ggml_tensor * tok_embd,
- int64_t n_embd,
- int32_t n_tokens,
const llm_build_cb & cb) {
+ const int64_t n_embd = hparams.n_embd;
+
struct ggml_tensor * inpL;
if (batch.token) {
- struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
+ struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
cb(inp_tokens, "inp_tokens", -1);
inpL = ggml_get_rows(ctx, tok_embd, inp_tokens);
@@ -3117,7 +3139,7 @@ static struct ggml_tensor * llm_build_inp_embd(
GGML_ASSERT(false && "not implemented");
#endif
- inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
+ inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
}
return inpL;
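// NOTE: the cb callback is how the caller attaches names to the tensors built
//       here (and, via those names, decides things like backend offloading).
//       A minimal example of such a callback -- an assumption, the real one is
//       defined by the caller (llama_build_graph), outside this diff:
//
//           llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
//               if (il >= 0) {
//                   ggml_format_name(cur, "%s-%d", name, il);
//               } else {
//                   ggml_set_name(cur, name);
//               }
//           };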
@@ -3126,28 +3148,21 @@ static struct ggml_tensor * llm_build_inp_embd(
// Persimmon: n_rot = n_embd_head/2
// Other: n_rot = n_embd_head
static void llm_build_k_shift(
- const llama_context & lctx,
- struct ggml_context * ctx,
- struct ggml_cgraph * graph,
- int64_t n_rot,
- llm_rope_type type,
- const llm_build_cb & cb) {
- const auto & model = lctx.model;
- const auto & kv_self = lctx.kv_self;
- const auto & cparams = lctx.cparams;
-
- const auto & hparams = model.hparams;
-
+ struct ggml_context * ctx,
+ const llama_hparams & hparams,
+ const llama_kv_cache & kv,
+ struct ggml_cgraph * graph,
+ llm_rope_type type,
+ int64_t n_ctx,
+ int64_t n_rot,
+ float freq_base,
+ float freq_scale,
+ const llm_build_cb & cb) {
const int64_t n_layer = hparams.n_layer;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_gqa = hparams.n_embd_gqa();
const int64_t n_embd_head = hparams.n_embd_head();
- const int64_t n_ctx = lctx.cparams.n_ctx;
-
- const float freq_base = cparams.rope_freq_base;
- const float freq_scale = cparams.rope_freq_scale;
-
GGML_ASSERT(n_embd_head % n_rot == 0);
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
@@ -3165,11 +3180,11 @@ static void llm_build_k_shift(
struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions
ggml_rope_custom_inplace(ctx,
- ggml_view_3d(ctx, kv_self.k,
+ ggml_view_3d(ctx, kv.k,
n_rot, n_head_kv, n_ctx,
- ggml_element_size(kv_self.k)*n_embd_head,
- ggml_element_size(kv_self.k)*n_embd_gqa,
- ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
+ ggml_element_size(kv.k)*n_embd_head,
+ ggml_element_size(kv.k)*n_embd_gqa,
+ ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
cb(tmp, "K_shifted", il);
ggml_build_forward_expand(graph, tmp);
@@ -3177,22 +3192,17 @@ static void llm_build_k_shift(
}
static void llm_build_kv_store(
- const llama_context & lctx,
struct ggml_context * ctx,
+ const llama_hparams & hparams,
+ const llama_kv_cache & kv,
struct ggml_cgraph * graph,
struct ggml_tensor * k_cur,
struct ggml_tensor * v_cur,
+ int64_t n_ctx,
int32_t n_tokens,
int32_t kv_head,
const llm_build_cb & cb,
int64_t il) {
- const auto & model = lctx.model;
- const auto & kv_self = lctx.kv_self;
- const auto & cparams = lctx.cparams;
-
- const auto & hparams = model.hparams;
-
- const int64_t n_ctx = cparams.n_ctx;
const int64_t n_embd_gqa = hparams.n_embd_gqa();
// compute the transposed [n_tokens, n_embd] V matrix
@@ -3200,13 +3210,13 @@ static void llm_build_kv_store(
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
cb(v_cur_t, "v_cur_t", il);
- struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv_self.k, n_tokens*n_embd_gqa,
- (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+ struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k, n_tokens*n_embd_gqa,
+ (ggml_element_size(kv.k)*n_embd_gqa)*(il*n_ctx + kv_head));
cb(k_cache_view, "k_cache_view", il);
- struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv_self.v, n_tokens, n_embd_gqa,
- ( n_ctx)*ggml_element_size(kv_self.v),
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+ struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v, n_tokens, n_embd_gqa,
+ ( n_ctx)*ggml_element_size(kv.v),
+ (il*n_ctx)*ggml_element_size(kv.v)*n_embd_gqa + kv_head*ggml_element_size(kv.v));
cb(v_cache_view, "v_cache_view", il);
// important: storing RoPE-ed version of K in the KV cache!
@@ -3214,23 +3224,18 @@ static void llm_build_kv_store(
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
}
-enum llm_norm_type {
- LLM_NORM,
- LLM_NORM_RMS,
-};
-
static struct ggml_tensor * llm_build_norm(
struct ggml_context * ctx,
struct ggml_tensor * cur,
+ const llama_hparams & hparams,
struct ggml_tensor * mw,
struct ggml_tensor * mb,
llm_norm_type type,
- float eps,
const llm_build_cb & cb,
int il) {
switch (type) {
- case LLM_NORM: cur = ggml_norm (ctx, cur, eps); break;
- case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, eps); break;
+ case LLM_NORM: cur = ggml_norm (ctx, cur, hparams.f_norm_eps); break;
+ case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps); break;
}
if (mw || mb) {
@@ -3251,18 +3256,6 @@ static struct ggml_tensor * llm_build_norm(
return cur;
}
-enum llm_ffn_op_type {
- LLM_FFN_SILU,
- LLM_FFN_GELU,
- LLM_FFN_RELU,
- LLM_FFN_RELU_SQR,
-};
-
-enum llm_ffn_gate_type {
- LLM_FFN_SEQ,
- LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
-};
-
static struct ggml_tensor * llm_build_ffn(
struct ggml_context * ctx,
struct ggml_tensor * cur,
@@ -3351,26 +3344,21 @@ static struct ggml_tensor * llm_build_ffn(
// if max_alibi_bias > 0 then apply ALiBi
static struct ggml_tensor * llm_build_kqv(
- const llama_context & lctx,
struct ggml_context * ctx,
struct ggml_tensor * cur,
+ const llama_hparams & hparams,
+ const llama_kv_cache & kv,
struct ggml_tensor * wo,
struct ggml_tensor * wo_b,
struct ggml_tensor * q_cur,
struct ggml_tensor * kq_scale,
struct ggml_tensor * kq_mask,
+ int64_t n_ctx,
int32_t n_tokens,
int32_t n_kv,
- float alibi_bias_max,
+ float max_alibi_bias,
const llm_build_cb & cb,
- int il) {
- const auto & model = lctx.model;
- const auto & kv_self = lctx.kv_self;
- const auto & cparams = lctx.cparams;
-
- const auto & hparams = model.hparams;
-
- const int64_t n_ctx = cparams.n_ctx;
+ int il) {
const int64_t n_embd = hparams.n_embd;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
@@ -3381,11 +3369,11 @@ static struct ggml_tensor * llm_build_kqv(
cb(q, "q", il);
struct ggml_tensor * k =
- ggml_view_3d(ctx, kv_self.k,
+ ggml_view_3d(ctx, kv.k,
n_embd_head, n_kv, n_head_kv,
- ggml_element_size(kv_self.k)*n_embd_gqa,
- ggml_element_size(kv_self.k)*n_embd_head,
- ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+ ggml_element_size(kv.k)*n_embd_gqa,
+ ggml_element_size(kv.k)*n_embd_head,
+ ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il);
cb(k, "k", il);
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
@@ -3394,11 +3382,11 @@ static struct ggml_tensor * llm_build_kqv(
kq = ggml_scale(ctx, kq, kq_scale);
cb(kq, "kq_scaled", il);
- if (alibi_bias_max > 0.0f) {
+ if (max_alibi_bias > 0.0f) {
// TODO: n_head or n_head_kv
// TODO: K-shift is likely not working
// TODO: change to ggml_add
- kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, alibi_bias_max);
+ kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
cb(kq, "kq_scaled_alibi", il);
}
@@ -3410,11 +3398,11 @@ static struct ggml_tensor * llm_build_kqv(
// split cached v into n_head heads
struct ggml_tensor * v =
- ggml_view_3d(ctx, kv_self.v,
+ ggml_view_3d(ctx, kv.v,
n_kv, n_embd_head, n_head_kv,
- ggml_element_size(kv_self.v)*n_ctx,
- ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
- ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+ ggml_element_size(kv.v)*n_ctx,
+ ggml_element_size(kv.v)*n_ctx*n_embd_head,
+ ggml_element_size(kv.v)*n_ctx*n_embd_gqa*il);
cb(v, "v", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
@@ -3438,1259 +3426,1011 @@ static struct ggml_tensor * llm_build_kqv(
return cur;
}
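// NOTE: kq_scale and kq_mask are graph inputs created by the build_*() methods
//       below and filled in by the caller before evaluation; the assumed
//       contents (not shown in this diff):
//
//           ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head)));  // 1/sqrt(d_head)
//           // KQ_mask: one [n_kv, n_tokens] slice, 0.0f where a token may
//           // attend and -INFINITY where it may not, broadcast to all heads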
-static struct ggml_cgraph * llm_build_llama(
- llama_context & lctx,
- const llama_batch & batch,
- const llm_build_cb & cb,
- bool worst_case) {
- const auto & model = lctx.model;
- const auto & hparams = model.hparams;
- const auto & cparams = lctx.cparams;
-
- const auto & kv_self = lctx.kv_self;
+struct llm_build_context {
+ const llama_model & model;
+ const llama_hparams & hparams;
+ const llama_cparams & cparams;
+ const llama_batch & batch;
+ const llama_kv_cache & kv_self;
- GGML_ASSERT(!!kv_self.ctx);
+ const int64_t n_embd;
+ const int64_t n_layer;
+ const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
+ const int64_t n_head;
+ const int64_t n_head_kv;
+ const int64_t n_embd_head;
+ const int64_t n_embd_gqa;
- const int64_t n_embd = hparams.n_embd;
- const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = cparams.n_ctx;
- const int64_t n_head = hparams.n_head;
- const int64_t n_head_kv = hparams.n_head_kv;
- const int64_t n_embd_head = hparams.n_embd_head();
+ const float freq_base;
+ const float freq_scale;
+ const float norm_eps;
+ const float norm_rms_eps;
- GGML_ASSERT(n_embd_head == hparams.n_rot);
+ const int32_t n_tokens;
+ const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx)
+ const int32_t kv_head; // index of where we store new KV data in the cache
- const float freq_base = cparams.rope_freq_base;
- const float freq_scale = cparams.rope_freq_scale;
- const float norm_rms_eps = hparams.f_norm_rms_eps;
+ const bool do_rope_shift;
- const int32_t n_tokens = batch.n_tokens;
- const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
- const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
+ const llm_build_cb & cb;
- const bool do_rope_shift = worst_case || kv_self.has_shift;
+ llama_buffer & buf_compute;
- //printf("n_kv = %d\n", n_kv);
+ struct ggml_context * ctx0 = nullptr;
- auto & buf_compute = lctx.buf_compute;
-
- struct ggml_init_params params = {
- /*.mem_size =*/ buf_compute.size,
- /*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ true,
- };
-
- struct ggml_context * ctx0 = ggml_init(params);
-
- ggml_cgraph * gf = ggml_new_graph(ctx0);
-
- struct ggml_tensor * cur;
- struct ggml_tensor * inpL;
-
- inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
- cb(inpL, "inp_embd", -1);
-
- // inp_pos - contains the positions
- struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
- cb(inp_pos, "inp_pos", -1);
-
- // KQ_scale
- struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
- cb(KQ_scale, "KQ_scale", -1);
-
- // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
- struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
- cb(KQ_mask, "KQ_mask", -1);
+ // TODO: consider making the entire interface noexcept
+ llm_build_context(
+ llama_context & lctx,
+ const llama_batch & batch,
+ const llm_build_cb & cb,
+ bool worst_case) :
+ model (lctx.model),
+ hparams (model.hparams),
+ cparams (lctx.cparams),
+ batch (batch),
+ kv_self (lctx.kv_self),
+ n_embd (hparams.n_embd),
+ n_layer (hparams.n_layer),
+ n_ctx (cparams.n_ctx),
+ n_head (hparams.n_head),
+ n_head_kv (hparams.n_head_kv),
+ n_embd_head (hparams.n_embd_head()),
+ n_embd_gqa (hparams.n_embd_gqa()),
+ freq_base (cparams.rope_freq_base),
+ freq_scale (cparams.rope_freq_scale),
+ norm_eps (hparams.f_norm_eps),
+ norm_rms_eps (hparams.f_norm_rms_eps),
+ n_tokens (batch.n_tokens),
+ n_kv (worst_case ? n_ctx : kv_self.n),
+ kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
+ do_rope_shift (worst_case || kv_self.has_shift),
+ cb (cb),
+ buf_compute (lctx.buf_compute) {
+ GGML_ASSERT(!!kv_self.ctx);
+
+ // all initializations should be done in init()
+ }
+
+ void init() {
+ struct ggml_init_params params = {
+ /*.mem_size =*/ buf_compute.size,
+ /*.mem_buffer =*/ buf_compute.data,
+ /*.no_alloc =*/ true,
+ };
- // shift the entire K-cache if needed
- if (do_rope_shift) {
- llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
+ ctx0 = ggml_init(params);
}
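// NOTE: with no_alloc = true, ctx0 only records tensor and graph metadata in
//       buf_compute -- the tensor data itself is placed later by the graph
//       allocator. This is also why the constructor takes worst_case: the same
//       builders are reused to measure the largest possible graph before the
//       compute buffers are sized.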
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * inpSA = inpL;
-
- // norm
- cur = llm_build_norm(ctx0, inpL,
- model.layers[il].attn_norm, NULL,
- LLM_NORM_RMS, norm_rms_eps, cb, il);
- cb(cur, "attn_norm", il);
-
- // self-attention
- {
- // compute Q and K and RoPE them
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
- cb(Qcur, "Qcur", il);
-
- struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
- cb(Kcur, "Kcur", il);
-
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
- cb(Vcur, "Vcur", il);
-
- Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
- cb(Qcur, "Qcur", il);
-
- Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
- cb(Kcur, "Kcur", il);
-
- llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
-
- cur = llm_build_kqv(lctx, ctx0, cur,
- model.layers[il].wo, NULL,
- Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il);
- cb(cur, "kqv_out", il);
- }
-
- struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
- cb(ffn_inp, "ffn_inp", il);
-
- // feed-forward network
- {
- cur = llm_build_norm(ctx0, ffn_inp,
- model.layers[il].ffn_norm, NULL,
- LLM_NORM_RMS, norm_rms_eps, cb, il);
- cb(cur, "ffn_norm", il);
-
- cur = llm_build_ffn(ctx0, cur,
- model.layers[il].ffn_up, NULL,
- model.layers[il].ffn_gate, NULL,
- model.layers[il].ffn_down, NULL,
- LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
- cb(cur, "ffn_out", il);
+ void free() {
+ if (ctx0) {
+ ggml_free(ctx0);
+ ctx0 = nullptr;
}
-
- cur = ggml_add(ctx0, cur, ffn_inp);
- cb(cur, "l_out", il);
-
- // input for next layer
- inpL = cur;
}
- cur = inpL;
-
- cur = llm_build_norm(ctx0, cur,
- model.output_norm, NULL,
- LLM_NORM_RMS, norm_rms_eps, cb, -1);
- cb(cur, "result_norm", -1);
-
- // lm_head
- cur = ggml_mul_mat(ctx0, model.output, cur);
- cb(cur, "result_output", -1);
-
- ggml_build_forward_expand(gf, cur);
-
- ggml_free(ctx0);
-
- return gf;
-}
-
-static struct ggml_cgraph * llm_build_baichaun(
- llama_context & lctx,
- const llama_batch & batch,
- const llm_build_cb & cb,
- bool worst_case) {
- const auto & model = lctx.model;
- const auto & hparams = model.hparams;
- const auto & cparams = lctx.cparams;
-
- const auto & kv_self = lctx.kv_self;
-
- GGML_ASSERT(!!kv_self.ctx);
-
- const int64_t n_embd = hparams.n_embd;
- const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = cparams.n_ctx;
- const int64_t n_head = hparams.n_head;
- const int64_t n_head_kv = hparams.n_head_kv;
- const int64_t n_embd_head = hparams.n_embd_head();
-
- GGML_ASSERT(n_embd_head == hparams.n_rot);
-
- const float freq_base = cparams.rope_freq_base;
- const float freq_scale = cparams.rope_freq_scale;
- const float norm_rms_eps = hparams.f_norm_rms_eps;
+ struct ggml_cgraph * build_llama() {
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
- const int32_t n_tokens = batch.n_tokens;
- const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
- const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
- const bool do_rope_shift = worst_case || kv_self.has_shift;
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
- auto & buf_compute = lctx.buf_compute;
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+ cb(inpL, "inp_embd", -1);
- struct ggml_init_params params = {
- /*.mem_size =*/ buf_compute.size,
- /*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ true,
- };
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ cb(inp_pos, "inp_pos", -1);
- struct ggml_context * ctx0 = ggml_init(params);
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ cb(KQ_scale, "KQ_scale", -1);
- ggml_cgraph * gf = ggml_new_graph(ctx0);
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ cb(KQ_mask, "KQ_mask", -1);
- struct ggml_tensor * cur;
- struct ggml_tensor * inpL;
+ // shift the entire K-cache if needed
+ if (do_rope_shift) {
+ llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+ }
- inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
- cb(inpL, "inp_embd", -1);
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
- // inp_pos - contains the positions
- struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
- cb(inp_pos, "inp_pos", -1);
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
- // KQ_scale
- struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
- cb(KQ_scale, "KQ_scale", -1);
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
- // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
- struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
- cb(KQ_mask, "KQ_mask", -1);
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
- // shift the entire K-cache if needed
- if (do_rope_shift) {
- llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
- }
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * inpSA = inpL;
+ Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+ cb(Qcur, "Qcur", il);
- cur = llm_build_norm(ctx0, inpL,
- model.layers[il].attn_norm, NULL,
- LLM_NORM_RMS, norm_rms_eps, cb, il);
- cb(cur, "attn_norm", il);
+ Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+ cb(Kcur, "Kcur", il);
- // self-attention
- {
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
- cb(Qcur, "Qcur", il);
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
- cb(Kcur, "Kcur", il);
+ cur = llm_build_kqv(ctx0, cur, hparams, kv_self,
+ model.layers[il].wo, NULL,
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+ cb(cur, "kqv_out", il);
+ }
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
- cb(Vcur, "Vcur", il);
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
- switch (model.type) {
- case MODEL_7B:
- Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
- Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
- break;
- case MODEL_13B:
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
- break;
- default:
- GGML_ASSERT(false);
+ // feed-forward network
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, NULL,
+ model.layers[il].ffn_gate, NULL,
+ model.layers[il].ffn_down, NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
}
- cb(Qcur, "Qcur", il);
- cb(Kcur, "Kcur", il);
- llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "l_out", il);
- // apply ALiBi for 13B model
- const float alibi_bias_max = model.type == MODEL_13B ? 8.0f : -1.0f;
-
- cur = llm_build_kqv(lctx, ctx0, cur,
- model.layers[il].wo, NULL,
- Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, alibi_bias_max, cb, il);
- cb(cur, "kqv_out", il);
+ // input for next layer
+ inpL = cur;
}
- struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
- cb(ffn_inp, "ffn_inp", il);
+ cur = inpL;
- // feed-forward network
- {
- cur = llm_build_norm(ctx0, ffn_inp,
- model.layers[il].ffn_norm, NULL,
- LLM_NORM_RMS, norm_rms_eps, cb, il);
- cb(cur, "ffn_norm", il);
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
- cur = llm_build_ffn(ctx0, cur,
- model.layers[il].ffn_up, NULL,
- model.layers[il].ffn_gate, NULL,
- model.layers[il].ffn_down, NULL,
- LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
- cb(cur, "ffn_out", il);
- }
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
- cur = ggml_add(ctx0, cur, ffn_inp);
- cb(cur, "l_out", il);
+ ggml_build_forward_expand(gf, cur);
- // input for next layer
- inpL = cur;
+ return gf;
}
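// NOTE: a sketch of how these builders are expected to be driven by the caller
//       (assumed -- the dispatch on model.arch is outside this hunk):
//
//           llm_build_context llm(lctx, batch, cb, worst_case);
//           llm.init();
//
//           struct ggml_cgraph * gf = nullptr;
//           switch (model.arch) {
//               case LLM_ARCH_LLAMA:    gf = llm.build_llama();    break;
//               case LLM_ARCH_BAICHUAN: gf = llm.build_baichuan(); break;
//               case LLM_ARCH_FALCON:   gf = llm.build_falcon();   break;
//               // ...
//               default: GGML_ASSERT(false);
//           }
//
//           llm.free(); // releases ctx0 only; the graph data lives in buf_compute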
- cur = inpL;
+ struct ggml_cgraph * build_baichuan() {
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
- cur = llm_build_norm(ctx0, cur,
- model.output_norm, NULL,
- LLM_NORM_RMS, norm_rms_eps, cb, -1);
- cb(cur, "result_norm", -1);
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
- // lm_head
- cur = ggml_mul_mat(ctx0, model.output, cur);
- cb(cur, "result_output", -1);
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+ cb(inpL, "inp_embd", -1);
- ggml_build_forward_expand(gf, cur);
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ cb(inp_pos, "inp_pos", -1);
- ggml_free(ctx0);
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ cb(KQ_scale, "KQ_scale", -1);
- return gf;
-}
-
-static struct ggml_cgraph * llm_build_falcon(
- llama_context & lctx,
- const llama_batch & batch,
- const llm_build_cb & cb,
- bool worst_case) {
- const auto & model = lctx.model;
- const auto & hparams = model.hparams;
- const auto & cparams = lctx.cparams;
-
- const auto & kv_self = lctx.kv_self;
-
- GGML_ASSERT(!!kv_self.ctx);
-
- const int64_t n_embd = hparams.n_embd;
- const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = cparams.n_ctx;
- const int64_t n_head = hparams.n_head;
- const int64_t n_head_kv = hparams.n_head_kv;
- const int64_t n_embd_head = hparams.n_embd_head();
- const int64_t n_embd_gqa = hparams.n_embd_gqa();
-
- GGML_ASSERT(n_embd_head == hparams.n_rot);
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ cb(KQ_mask, "KQ_mask", -1);
- const float freq_base = cparams.rope_freq_base;
- const float freq_scale = cparams.rope_freq_scale;
- const float norm_eps = hparams.f_norm_eps;
-
- const int32_t n_tokens = batch.n_tokens;
- const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
- const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
-
- const bool do_rope_shift = worst_case || kv_self.has_shift;
-
- //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n",
- // kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift);
-
- auto & buf_compute = lctx.buf_compute;
-
- struct ggml_init_params params = {
- /*.mem_size =*/ buf_compute.size,
- /*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ true,
- };
-
- struct ggml_context * ctx0 = ggml_init(params);
-
- ggml_cgraph * gf = ggml_new_graph(ctx0);
+ // shift the entire K-cache if needed
+ if (do_rope_shift) {
+ llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+ }
- struct ggml_tensor * cur;
- struct ggml_tensor * inpL;
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
- inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
- cb(inpL, "inp_embd", -1);
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
- // inp_pos - contains the positions
- struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
- cb(inp_pos, "inp_pos", -1);
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
- // KQ_scale
- struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
- cb(KQ_scale, "KQ_scale", -1);
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
- // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
- struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
- cb(KQ_mask, "KQ_mask", -1);
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
- // shift the entire K-cache if needed
- if (do_rope_shift) {
- llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE_NEOX, cb);
- }
+ switch (model.type) {
+ case MODEL_7B:
+ Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+ Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+ break;
+ case MODEL_13B:
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
+ break;
+ default:
+ GGML_ASSERT(false);
+ }
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * attn_norm;
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- attn_norm = llm_build_norm(ctx0, inpL,
- model.layers[il].attn_norm,
- model.layers[il].attn_norm_b,
- LLM_NORM, norm_eps, cb, il);
- cb(attn_norm, "attn_norm", il);
+ // apply ALiBi for 13B model
+ const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
- // self-attention
- {
- if (model.layers[il].attn_norm_2) {
- // Falcon-40B
- cur = llm_build_norm(ctx0, attn_norm,
- model.layers[il].attn_norm_2,
- model.layers[il].attn_norm_2_b,
- LLM_NORM, norm_eps, cb, il);
- cb(cur, "attn_norm_2", il);
- } else {
- cur = attn_norm;
+ cur = llm_build_kqv(ctx0, cur, hparams, kv_self,
+ model.layers[il].wo, NULL,
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, cb, il);
+ cb(cur, "kqv_out", il);
}
- cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
- cb(cur, "wqkv", il);
-
- struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
- struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
- struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
- cb(Qcur, "Qcur", il);
- cb(Kcur, "Kcur", il);
- cb(Vcur, "Vcur", il);
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
- // using mode = 2 for neox mode
- Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
- cb(Qcur, "Qcur", il);
-
- Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
- cb(Kcur, "Kcur", il);
+ // feed-forward network
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, NULL,
+ model.layers[il].ffn_gate, NULL,
+ model.layers[il].ffn_down, NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+ }
- llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "l_out", il);
- cur = llm_build_kqv(lctx, ctx0, attn_norm,
- model.layers[il].wo, NULL,
- Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il);
- cb(cur, "kqv_out", il);
+ // input for next layer
+ inpL = cur;
}
- struct ggml_tensor * ffn_inp = cur;
+ cur = inpL;
- // feed forward
- {
- cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result
- model.layers[il].ffn_up, NULL,
- NULL, NULL,
- model.layers[il].ffn_down, NULL,
- LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
- cb(cur, "ffn_out", il);
- }
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
- cur = ggml_add(ctx0, cur, ffn_inp);
- cb(cur, "l_out", il);
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
- cur = ggml_add(ctx0, cur, inpL);
- cb(cur, "l_out", il);
+ ggml_build_forward_expand(gf, cur);
- // input for next layer
- inpL = cur;
+ return gf;
}
- cur = inpL;
-
- // norm
- cur = llm_build_norm(ctx0, cur,
- model.output_norm,
- model.output_norm_b,
- LLM_NORM, norm_eps, cb, -1);
- cb(cur, "result_norm", -1);
-
- cur = ggml_mul_mat(ctx0, model.output, cur);
- cb(cur, "result_output", -1);
+ struct ggml_cgraph * build_falcon() {
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
- ggml_build_forward_expand(gf, cur);
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
- ggml_free(ctx0);
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+ cb(inpL, "inp_embd", -1);
- return gf;
-}
-
-static struct ggml_cgraph * llm_build_starcoder(
- llama_context & lctx,
- const llama_batch & batch,
- const llm_build_cb & cb,
- bool worst_case) {
- const auto & model = lctx.model;
- const auto & hparams = model.hparams;
- const auto & cparams = lctx.cparams;
-
- const auto & kv_self = lctx.kv_self;
-
- GGML_ASSERT(!!kv_self.ctx);
-
- const int64_t n_embd = hparams.n_embd;
- const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = cparams.n_ctx;
- const int64_t n_head = hparams.n_head;
- const int64_t n_embd_head = hparams.n_embd_head();
- const int64_t n_embd_gqa = hparams.n_embd_gqa();
-
- GGML_ASSERT(n_embd_head == hparams.n_rot);
-
- const float norm_eps = hparams.f_norm_eps;
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ cb(inp_pos, "inp_pos", -1);
- const int32_t n_tokens = batch.n_tokens;
- const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
- const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ cb(KQ_scale, "KQ_scale", -1);
- auto & buf_compute = lctx.buf_compute;
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ cb(KQ_mask, "KQ_mask", -1);
- struct ggml_init_params params = {
- /*.mem_size =*/ buf_compute.size,
- /*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ true,
- };
-
- struct ggml_context * ctx0 = ggml_init(params);
+ // shift the entire K-cache if needed
+ if (do_rope_shift) {
+ llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+ }
- ggml_cgraph * gf = ggml_new_graph(ctx0);
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * attn_norm;
- struct ggml_tensor * cur;
- struct ggml_tensor * pos;
- struct ggml_tensor * inpL;
+ attn_norm = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm,
+ model.layers[il].attn_norm_b,
+ LLM_NORM, cb, il);
+ cb(attn_norm, "attn_norm", il);
- inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
- cb(inpL, "inp_embd", -1);
+ // self-attention
+ {
+ if (model.layers[il].attn_norm_2) {
+ // Falcon-40B
+ cur = llm_build_norm(ctx0, attn_norm, hparams,
+ model.layers[il].attn_norm_2,
+ model.layers[il].attn_norm_2_b,
+ LLM_NORM, cb, il);
+ cb(cur, "attn_norm_2", il);
+ } else {
+ cur = attn_norm;
+ }
- // inp_pos - contains the positions
- struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
- cb(inp_pos, "inp_pos", -1);
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+ cb(cur, "wqkv", il);
- // KQ_scale
- struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
- cb(KQ_scale, "KQ_scale", -1);
+ struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
- // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
- struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
- cb(KQ_mask, "KQ_mask", -1);
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
- pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
- cb(pos, "pos_embd", -1);
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
- inpL = ggml_add(ctx0, inpL, pos);
- cb(inpL, "inpL", -1);
+ // using mode = 2 for neox mode
+ Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
+ cb(Qcur, "Qcur", il);
- for (int il = 0; il < n_layer; ++il) {
- cur = llm_build_norm(ctx0, inpL,
- model.layers[il].attn_norm,
- model.layers[il].attn_norm_b,
- LLM_NORM, norm_eps, cb, il);
- cb(cur, "attn_norm", il);
+ Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
+ cb(Kcur, "Kcur", il);
- // self-attention
- {
- cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
- cb(cur, "wqkv", il);
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
- cb(cur, "bqkv", il);
+ cur = llm_build_kqv(ctx0, attn_norm, hparams, kv_self,
+ model.layers[il].wo, NULL,
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+ cb(cur, "kqv_out", il);
+ }
- struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
- struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
- struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+ struct ggml_tensor * ffn_inp = cur;
- cb(Qcur, "Qcur", il);
- cb(Kcur, "Kcur", il);
- cb(Vcur, "Vcur", il);
+ // feed forward
+ {
+ cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result
+ model.layers[il].ffn_up, NULL,
+ NULL, NULL,
+ model.layers[il].ffn_down, NULL,
+ LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+ cb(cur, "ffn_out", il);
+ }
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "l_out", il);
- llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+ cur = ggml_add(ctx0, cur, inpL);
+ cb(cur, "l_out", il);
- cur = llm_build_kqv(lctx, ctx0, cur,
- model.layers[il].wo, model.layers[il].bo,
- Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il);
- cb(cur, "kqv_out", il);
+ // input for next layer
+ inpL = cur;
}
- // add the input
- struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
- cb(ffn_inp, "ffn_inp", il);
+ cur = inpL;
- // FF
- {
- cur = llm_build_norm(ctx0, ffn_inp,
- model.layers[il].ffn_norm,
- model.layers[il].ffn_norm_b,
- LLM_NORM, norm_eps, cb, il);
- cb(cur, "ffn_norm", il);
-
- cur = llm_build_ffn(ctx0, cur,
- model.layers[il].ffn_up, model.layers[il].ffn_up_b,
- NULL, NULL,
- model.layers[il].ffn_down, model.layers[il].ffn_down_b,
- LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
- cb(cur, "ffn_out", il);
- }
+ // norm
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm,
+ model.output_norm_b,
+ LLM_NORM, cb, -1);
+ cb(cur, "result_norm", -1);
- inpL = ggml_add(ctx0, cur, ffn_inp);
- cb(inpL, "l_out", il);
- }
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
- cur = llm_build_norm(ctx0, inpL,
- model.output_norm,
- model.output_norm_b,
- LLM_NORM, norm_eps, cb, -1);
- cb(cur, "result_norm", -1);
+ ggml_build_forward_expand(gf, cur);
- cur = ggml_mul_mat(ctx0, model.output, cur);
- cb(cur, "result_output", -1);
+ return gf;
+ }
- ggml_build_forward_expand(gf, cur);
- ggml_free(ctx0);
+ struct ggml_cgraph * build_starcoder() {
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
- return gf;
-}
+ struct ggml_tensor * cur;
+ struct ggml_tensor * pos;
+ struct ggml_tensor * inpL;
-static struct ggml_cgraph * llm_build_persimmon(
- llama_context & lctx,
- const llama_batch & batch,
- const llm_build_cb & cb,
- bool worst_case) {
- const auto & model = lctx.model;
- const auto & hparams = model.hparams;
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+ cb(inpL, "inp_embd", -1);
- const auto & kv_self = lctx.kv_self;
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ cb(inp_pos, "inp_pos", -1);
- GGML_ASSERT(!!kv_self.ctx);
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ cb(KQ_scale, "KQ_scale", -1);
- const auto & cparams = lctx.cparams;
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ cb(KQ_mask, "KQ_mask", -1);
- const int64_t n_embd = hparams.n_embd;
- const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = cparams.n_ctx;
- const int64_t n_head_kv = hparams.n_head_kv;
- const int64_t n_head = hparams.n_head;
- const int64_t n_embd_head = hparams.n_embd_head();
- const int64_t n_rot = n_embd_head / 2;
+ pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
+ cb(pos, "pos_embd", -1);
- const float freq_base = cparams.rope_freq_base;
- const float freq_scale = cparams.rope_freq_scale;
- const float norm_eps = hparams.f_norm_eps;
+ inpL = ggml_add(ctx0, inpL, pos);
+ cb(inpL, "inpL", -1);
- const int32_t n_tokens = batch.n_tokens;
- const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
- const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
+ for (int il = 0; il < n_layer; ++il) {
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm,
+ model.layers[il].attn_norm_b,
+ LLM_NORM, cb, il);
+ cb(cur, "attn_norm", il);
- const bool do_rope_shift = worst_case || kv_self.has_shift;
+ // self-attention
+ {
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+ cb(cur, "wqkv", il);
- auto & buf_compute = lctx.buf_compute;
+ cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+ cb(cur, "bqkv", il);
- struct ggml_init_params params = {
- /*.mem_size =*/ buf_compute.size,
- /*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ true,
- };
+ struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
- struct ggml_context * ctx0 = ggml_init(params);
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
- ggml_cgraph * gf = ggml_new_graph(ctx0);
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- struct ggml_tensor * cur;
- struct ggml_tensor * inpL;
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
- cb(inpL, "imp_embd", -1);
+ cur = llm_build_kqv(ctx0, cur, hparams, kv_self,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+ cb(cur, "kqv_out", il);
+ }
- struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
- cb(inp_pos, "inp_pos", -1);
+ // add the input
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+ cb(ffn_inp, "ffn_inp", il);
- // KQ_scale
- struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
- cb(KQ_scale, "KQ_scale", -1);
+ // FF
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm,
+ model.layers[il].ffn_norm_b,
+ LLM_NORM, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b,
+ NULL, NULL,
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+ LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+ cb(cur, "ffn_out", il);
+ }
- struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
- cb(KQ_mask, "KQ_mask", -1);
+ inpL = ggml_add(ctx0, cur, ffn_inp);
+ cb(inpL, "l_out", il);
+ }
- if (do_rope_shift) {
- llm_build_k_shift(lctx, ctx0, gf, n_rot, LLM_ROPE_NEOX, cb);
- }
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.output_norm,
+ model.output_norm_b,
+ LLM_NORM, cb, -1);
+ cb(cur, "result_norm", -1);
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * residual = inpL;
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
- cur = llm_build_norm(ctx0, inpL,
- model.layers[il].attn_norm,
- model.layers[il].attn_norm_b,
- LLM_NORM, norm_eps, cb, il);
- cb(cur, "attn_norm", il);
+ ggml_build_forward_expand(gf, cur);
- // self attention
- {
- cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
- cb(cur, "wqkv", il);
-
- cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
- cb(cur, "bqkv", il);
-
- // split qkv
- GGML_ASSERT(n_head_kv == n_head);
-
- struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens);
- cb(tmpqkv, "tmpqkv", il);
-
- struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2));
- cb(tmpqkv_perm, "tmpqkv", il);
-
- struct ggml_tensor * tmpq = ggml_view_3d(
- ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
- ggml_element_size(tmpqkv_perm) * n_embd_head,
- ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
- 0
- );
- cb(tmpq, "tmpq", il);
-
- struct ggml_tensor * tmpk = ggml_view_3d(
- ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
- ggml_element_size(tmpqkv_perm) * n_embd_head,
- ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
- ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens
- );
- cb(tmpk, "tmpk", il);
-
- // Q/K Layernorm
- tmpq = llm_build_norm(ctx0, tmpq,
- model.layers[il].attn_q_norm,
- model.layers[il].attn_q_norm_b,
- LLM_NORM, norm_eps, cb, il);
- cb(tmpq, "tmpq", il);
-
- tmpk = llm_build_norm(ctx0, tmpk,
- model.layers[il].attn_k_norm,
- model.layers[il].attn_k_norm_b,
- LLM_NORM, norm_eps, cb, il);
- cb(tmpk, "tmpk", il);
-
- // RoPE the first n_rot of q/k, pass the other half, and concat.
- struct ggml_tensor * qrot = ggml_view_3d(
- ctx0, tmpq, n_rot, n_head, n_tokens,
- ggml_element_size(tmpq) * n_embd_head,
- ggml_element_size(tmpq) * n_embd_head * n_head,
- 0
- );
- cb(qrot, "qrot", il);
+ return gf;
+ }
- struct ggml_tensor * krot = ggml_view_3d(
- ctx0, tmpk, n_rot, n_head, n_tokens,
- ggml_element_size(tmpk) * n_embd_head,
- ggml_element_size(tmpk) * n_embd_head * n_head,
- 0
- );
- cb(krot, "krot", il);
-
- // get the second half of tmpq, e.g tmpq[n_rot:, :, :]
- struct ggml_tensor * qpass = ggml_view_3d(
- ctx0, tmpq, n_rot, n_head, n_tokens,
- ggml_element_size(tmpq) * n_embd_head,
- ggml_element_size(tmpq) * n_embd_head * n_head,
- ggml_element_size(tmpq) * n_rot
- );
- cb(qpass, "qpass", il);
+ struct ggml_cgraph * build_persimmon() {
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
- struct ggml_tensor * kpass = ggml_view_3d(
- ctx0, tmpk, n_rot, n_head, n_tokens,
- ggml_element_size(tmpk) * n_embd_head,
- ggml_element_size(tmpk) * n_embd_head * n_head,
- ggml_element_size(tmpk) * n_rot
- );
- cb(kpass, "kpass", il);
+ const int64_t n_rot = n_embd_head / 2;
- struct ggml_tensor * qrotated = ggml_rope_custom(
- ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
- );
- cb(qrotated, "qrotated", il);
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
- struct ggml_tensor * krotated = ggml_rope_custom(
- ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
- );
- cb(krotated, "krotated", il);
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+ cb(inpL, "imp_embd", -1);
- // ggml currently only supports concatenation on dim=2
- // so we need to permute qrot, qpass, concat, then permute back.
- qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
- cb(qrotated, "qrotated", il);
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ cb(inp_pos, "inp_pos", -1);
- krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
- cb(krotated, "krotated", il);
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ cb(KQ_scale, "KQ_scale", -1);
- qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
- cb(qpass, "qpass", il);
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ cb(KQ_mask, "KQ_mask", -1);
- kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
- cb(kpass, "kpass", il);
+ if (do_rope_shift) {
+ llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+ }
- struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
- cb(Qcur, "Qcur", il);
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * residual = inpL;
- struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
- cb(Kcur, "Kcur", il);
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm,
+ model.layers[il].attn_norm_b,
+ LLM_NORM, cb, il);
+ cb(cur, "attn_norm", il);
- struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3));
- cb(Q, "Q", il);
+ // self attention
+ {
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+ cb(cur, "wqkv", il);
+
+ cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+ cb(cur, "bqkv", il);
+
+ // split qkv
+ GGML_ASSERT(n_head_kv == n_head);
+
+ struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens);
+ cb(tmpqkv, "tmpqkv", il);
+
+ struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2));
+ cb(tmpqkv_perm, "tmpqkv", il);
+
+ struct ggml_tensor * tmpq = ggml_view_3d(
+ ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
+ ggml_element_size(tmpqkv_perm) * n_embd_head,
+ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
+ 0
+ );
+ cb(tmpq, "tmpq", il);
+
+ struct ggml_tensor * tmpk = ggml_view_3d(
+ ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
+ ggml_element_size(tmpqkv_perm) * n_embd_head,
+ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
+ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens
+ );
+ cb(tmpk, "tmpk", il);
+
+ // Q/K Layernorm
+ tmpq = llm_build_norm(ctx0, tmpq, hparams,
+ model.layers[il].attn_q_norm,
+ model.layers[il].attn_q_norm_b,
+ LLM_NORM, cb, il);
+ cb(tmpq, "tmpq", il);
+
+ tmpk = llm_build_norm(ctx0, tmpk, hparams,
+ model.layers[il].attn_k_norm,
+ model.layers[il].attn_k_norm_b,
+ LLM_NORM, cb, il);
+ cb(tmpk, "tmpk", il);
+
+ // RoPE the first n_rot of q/k, pass the other half, and concat.
+ struct ggml_tensor * qrot = ggml_view_3d(
+ ctx0, tmpq, n_rot, n_head, n_tokens,
+ ggml_element_size(tmpq) * n_embd_head,
+ ggml_element_size(tmpq) * n_embd_head * n_head,
+ 0
+ );
+ cb(qrot, "qrot", il);
+
+ struct ggml_tensor * krot = ggml_view_3d(
+ ctx0, tmpk, n_rot, n_head, n_tokens,
+ ggml_element_size(tmpk) * n_embd_head,
+ ggml_element_size(tmpk) * n_embd_head * n_head,
+ 0
+ );
+ cb(krot, "krot", il);
+
+ // get the second half of tmpq, e.g tmpq[n_rot:, :, :]
+ struct ggml_tensor * qpass = ggml_view_3d(
+ ctx0, tmpq, n_rot, n_head, n_tokens,
+ ggml_element_size(tmpq) * n_embd_head,
+ ggml_element_size(tmpq) * n_embd_head * n_head,
+ ggml_element_size(tmpq) * n_rot
+ );
+ cb(qpass, "qpass", il);
+
+ struct ggml_tensor * kpass = ggml_view_3d(
+ ctx0, tmpk, n_rot, n_head, n_tokens,
+ ggml_element_size(tmpk) * n_embd_head,
+ ggml_element_size(tmpk) * n_embd_head * n_head,
+ ggml_element_size(tmpk) * n_rot
+ );
+ cb(kpass, "kpass", il);
+
+ struct ggml_tensor * qrotated = ggml_rope_custom(
+ ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
+ );
+ cb(qrotated, "qrotated", il);
+
+ struct ggml_tensor * krotated = ggml_rope_custom(
+ ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
+ );
+ cb(krotated, "krotated", il);
+
+ // ggml currently only supports concatenation on dim=2
+ // so we need to permute qrot, qpass, concat, then permute back.
+ qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
+ cb(qrotated, "qrotated", il);
+
+ krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
+ cb(krotated, "krotated", il);
+
+ qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
+ cb(qpass, "qpass", il);
+
+ kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
+ cb(kpass, "kpass", il);
+
+ struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
+ cb(Qcur, "Qcur", il);
+
+ struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3));
+ cb(Q, "Q", il);
+
+ Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
+ cb(Kcur, "Kcur", il);
+
+ struct ggml_tensor * Vcur = ggml_view_3d(
+ ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
+ ggml_element_size(tmpqkv_perm) * n_embd_head,
+ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
+ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2
+ );
+ cb(Vcur, "Vcur", il);
+
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
+
+ // TODO: not tested, could be broken
+ cur = llm_build_kqv(ctx0, Q, hparams, kv_self,
+ model.layers[il].wo, model.layers[il].bo,
+ Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+ cb(cur, "kqv_out", il);
+ }
- Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
- cb(Kcur, "Kcur", il);
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
+ cb(ffn_inp, "ffn_inp", il);
- struct ggml_tensor * Vcur = ggml_view_3d(
- ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
- ggml_element_size(tmpqkv_perm) * n_embd_head,
- ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
- ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2
- );
- cb(Vcur, "Vcur", il);
+ // feed-forward network
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm,
+ model.layers[il].ffn_norm_b,
+ LLM_NORM, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b,
+ NULL, NULL,
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+ LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
+ cb(cur, "ffn_out", il);
+ }
- llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "l_out", il);
- // TODO: not tested, could be broken
- cur = llm_build_kqv(lctx, ctx0, Q,
- model.layers[il].wo, model.layers[il].bo,
- Q, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il);
- cb(cur, "kqv_out", il);
+ inpL = cur;
}
- struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
- cb(ffn_inp, "ffn_inp", il);
+ cur = inpL;
- // feed-forward network
- {
- cur = llm_build_norm(ctx0, ffn_inp,
- model.layers[il].ffn_norm,
- model.layers[il].ffn_norm_b,
- LLM_NORM, norm_eps, cb, il);
- cb(cur, "ffn_norm", il);
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm,
+ model.output_norm_b,
+ LLM_NORM, cb, -1);
+ cb(cur, "result_norm", -1);
- cur = llm_build_ffn(ctx0, cur,
- model.layers[il].ffn_up, model.layers[il].ffn_up_b,
- NULL, NULL,
- model.layers[il].ffn_down, model.layers[il].ffn_down_b,
- LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
- cb(cur, "ffn_out", il);
- }
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
- cur = ggml_add(ctx0, cur, ffn_inp);
- cb(cur, "l_out", il);
+ ggml_build_forward_expand(gf, cur);
- inpL = cur;
+ return gf;
}
- cur = inpL;
-
- cur = llm_build_norm(ctx0, cur,
- model.output_norm,
- model.output_norm_b,
- LLM_NORM, norm_eps, cb, -1);
- cb(cur, "result_norm", -1);
-
- cur = ggml_mul_mat(ctx0, model.output, cur);
- cb(cur, "result_output", -1);
-
- ggml_build_forward_expand(gf, cur);
-
- ggml_free(ctx0);
-
- return gf;
-}
-
-static struct ggml_cgraph * llm_build_refact(
- llama_context & lctx,
- const llama_batch & batch,
- const llm_build_cb & cb,
- bool worst_case) {
- const auto & model = lctx.model;
- const auto & hparams = model.hparams;
- const auto & cparams = lctx.cparams;
-
- const auto & kv_self = lctx.kv_self;
-
- GGML_ASSERT(!!kv_self.ctx);
-
- const int64_t n_embd = hparams.n_embd;
- const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = cparams.n_ctx;
- const int64_t n_head = hparams.n_head;
- const int64_t n_head_kv = hparams.n_head_kv;
- const int64_t n_embd_head = hparams.n_embd_head();
-
- const float norm_rms_eps = hparams.f_norm_rms_eps;
-
- const int32_t n_tokens = batch.n_tokens;
- const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
- const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
+ struct ggml_cgraph * build_refact() {
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
- auto & buf_compute = lctx.buf_compute;
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
- struct ggml_init_params params = {
- /*.mem_size =*/ buf_compute.size,
- /*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ true,
- };
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+ cb(inpL, "inp_embd", -1);
- struct ggml_context * ctx0 = ggml_init(params);
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ cb(KQ_scale, "KQ_scale", -1);
- ggml_cgraph * gf = ggml_new_graph(ctx0);
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ cb(KQ_mask, "KQ_mask", -1);
- struct ggml_tensor * cur;
- struct ggml_tensor * inpL;
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
- inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
- cb(inpL, "inp_embd", -1);
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
- // KQ_scale
- struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
- cb(KQ_scale, "KQ_scale", -1);
+ // self-attention
+ {
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
- // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
- struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
- cb(KQ_mask, "KQ_mask", -1);
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * inpSA = inpL;
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
- cur = llm_build_norm(ctx0, inpL,
- model.layers[il].attn_norm, NULL,
- LLM_NORM_RMS, norm_rms_eps, cb, il);
- cb(cur, "attn_norm", il);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ cb(Kcur, "Kcur", il);
- // self-attention
- {
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
- cb(Qcur, "Qcur", il);
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ cb(Qcur, "Qcur", il);
- struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
- cb(Kcur, "Kcur", il);
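+ // write K and V for the current tokens into the KV cache at slot kv_head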
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
- cb(Vcur, "Vcur", il);
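+ // Refact relies on ALiBi (no RoPE in this builder); 8.0f is the max ALiBi bias passed to llm_build_kqv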
+ cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
+ model.layers[il].wo, NULL,
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
+ cb(cur, "kqv_out", il);
+ }
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
- cb(Kcur, "Kcur", il);
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- cb(Qcur, "Qcur", il);
+ // feed-forward network
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
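+ // SwiGLU-style FFN (LLM_FFN_SILU + LLM_FFN_PAR): the gate projection is SILU-activated and multiplied with ffn_up before ffn_down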
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, NULL,
+ model.layers[il].ffn_gate, NULL,
+ model.layers[il].ffn_down, NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+ }
- llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "l_out", il);
- cur = llm_build_kqv(lctx, ctx0, Qcur,
- model.layers[il].wo, NULL,
- Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, 8.0f, cb, il);
- cb(cur, "kqv_out", il);
+ // input for next layer
+ inpL = cur;
}
- struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
- cb(ffn_inp, "ffn_inp", il);
+ cur = inpL;
- // feed-forward network
- {
- cur = llm_build_norm(ctx0, ffn_inp,
- model.layers[il].ffn_norm, NULL,
- LLM_NORM_RMS, norm_rms_eps, cb, il);
- cb(cur, "ffn_norm", il);
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
- cur = llm_build_ffn(ctx0, cur,
- model.layers[il].ffn_up, NULL,
- model.layers[il].ffn_gate, NULL,
- model.layers[il].ffn_down, NULL,
- LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
- cb(cur, "ffn_out", il);
- }
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
- cur = ggml_add(ctx0, cur, ffn_inp);
- cb(cur, "l_out", il);
+ ggml_build_forward_expand(gf, cur);
- // input for next layer
- inpL = cur;
+ return gf;
}
- cur = inpL;
-
- cur = llm_build_norm(ctx0, cur,
- model.output_norm, NULL,
- LLM_NORM_RMS, norm_rms_eps, cb, -1);
- cb(cur, "result_norm", -1);
-
- // lm_head
- cur = ggml_mul_mat(ctx0, model.output, cur);
- cb(cur, "result_output", -1);
-
- ggml_build_forward_expand(gf, cur);
-
- ggml_free(ctx0);
-
- return gf;
-}
-
-static struct ggml_cgraph * llm_build_bloom(
- llama_context & lctx,
- const llama_batch & batch,
- const llm_build_cb & cb,
- bool worst_case) {
- const auto & model = lctx.model;
- const auto & hparams = model.hparams;
- const auto & cparams = lctx.cparams;
-
- const auto & kv_self = lctx.kv_self;
-
- GGML_ASSERT(!!kv_self.ctx);
-
- const int64_t n_embd = hparams.n_embd;
- const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = cparams.n_ctx;
- const int64_t n_head = hparams.n_head;
- const int64_t n_embd_head = hparams.n_embd_head();
- const int64_t n_embd_gqa = hparams.n_embd_gqa();
-
- GGML_ASSERT(n_embd_head == hparams.n_rot);
-
- const float norm_eps = hparams.f_norm_eps;
-
- const int32_t n_tokens = batch.n_tokens;
- const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
- const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
-
- auto & buf_compute = lctx.buf_compute;
-
- struct ggml_init_params params = {
- /*.mem_size =*/ buf_compute.size,
- /*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ false,
- };
+ struct ggml_cgraph * build_bloom() {
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
- params.no_alloc = true;
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
- struct ggml_context * ctx0 = ggml_init(params);
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+ cb(inpL, "inp_embd", -1);
- ggml_cgraph * gf = ggml_new_graph(ctx0);
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ cb(KQ_scale, "KQ_scale", -1);
- struct ggml_tensor * cur;
- struct ggml_tensor * inpL;
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ cb(KQ_mask, "KQ_mask", -1);
- inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
- cb(inpL, "inp_embd", -1);
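+ // BLOOM applies an extra LayerNorm to the token embeddings before the first block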
+ inpL = llm_build_norm(ctx0, inpL, hparams,
+ model.tok_norm,
+ model.tok_norm_b,
+ LLM_NORM, cb, -1);
+ cb(inpL, "inp_norm", -1);
- // KQ_scale
- struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
- cb(KQ_scale, "KQ_scale", -1);
+ for (int il = 0; il < n_layer; ++il) {
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm,
+ model.layers[il].attn_norm_b,
+ LLM_NORM, cb, il);
+ cb(cur, "attn_norm", il);
- // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
- struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
- cb(KQ_mask, "KQ_mask", -1);
+ // self-attention
+ {
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+ cb(cur, "wqkv", il);
- inpL = llm_build_norm(ctx0, inpL,
- model.tok_norm,
- model.tok_norm_b,
- LLM_NORM, norm_eps, cb, -1);
- cb(inpL, "inp_norm", -1);
+ cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+ cb(cur, "bqkv", il);
- for (int il = 0; il < n_layer; ++il) {
- cur = llm_build_norm(ctx0, inpL,
- model.layers[il].attn_norm,
- model.layers[il].attn_norm_b,
- LLM_NORM, norm_eps, cb, il);
- cb(cur, "attn_norm", il);
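+ // per token, the fused QKV output holds n_embd values for Q followed by n_embd_gqa values each for K and V; take views at those offsets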
+ struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
- // self-attention
- {
- cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
- cb(cur, "wqkv", il);
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
- cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
- cb(cur, "bqkv", il);
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
- struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
- struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- cb(Qcur, "Qcur", il);
- cb(Kcur, "Kcur", il);
- cb(Vcur, "Vcur", il);
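+ // BLOOM also uses ALiBi, with a fixed max bias of 8.0f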
+ cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
+ cb(cur, "kqv_out", il);
+ }
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ // add the residual: attention output + block input
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+ cb(ffn_inp, "ffn_inp", il);
- llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+ // feed-forward network
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm,
+ model.layers[il].ffn_norm_b,
+ LLM_NORM, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b,
+ NULL, NULL,
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+ LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+ cb(cur, "ffn_out", il);
+ }
- cur = llm_build_kqv(lctx, ctx0, Qcur,
- model.layers[il].wo, model.layers[il].bo,
- Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, 8.0f, cb, il);
- cb(cur, "kqv_out", il);
+ inpL = ggml_add(ctx0, cur, ffn_inp);
+ cb(inpL, "l_out", il);
}
- // Add the input
- struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
- cb(ffn_inp, "ffn_inp", il);
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.output_norm,
+ model.output_norm_b,
+ LLM_NORM, cb, -1);
+ cb(cur, "result_norm", -1);
- // FF
- {
- cur = llm_build_norm(ctx0, ffn_inp,
- model.layers[il].ffn_norm,
- model.layers[il].ffn_norm_b,
- LLM_NORM, norm_eps, cb, il);
- cb(cur, "ffn_norm", il);
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
- cur = llm_build_ffn(ctx0, cur,
- model.layers[il].ffn_up, model.layers[il].ffn_up_b,
- NULL, NULL,
- model.layers[il].ffn_down, model.layers[il].ffn_down_b,
- LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
- cb(cur, "ffn_out", il);
- }
+ ggml_build_forward_expand(gf, cur);
- inpL = ggml_add(ctx0, cur, ffn_inp);
- cb(inpL, "l_out", il);
+ return gf;
}
- cur = llm_build_norm(ctx0, inpL,
- model.output_norm,
- model.output_norm_b,
- LLM_NORM, norm_eps, cb, -1);
- cb(cur, "result_norm", -1);
+ struct ggml_cgraph * build_mpt() {
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
- cur = ggml_mul_mat(ctx0, model.output, cur);
- cb(cur, "result_output", -1);
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
- ggml_build_forward_expand(gf, cur);
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+ cb(inpL, "inp_embd", -1);
- ggml_free(ctx0);
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ cb(KQ_scale, "KQ_scale", -1);
- return gf;
-}
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ cb(KQ_mask, "KQ_mask", -1);
-static struct ggml_cgraph * llm_build_mpt(
- llama_context & lctx,
- const llama_batch & batch,
- const llm_build_cb & cb,
- bool worst_case) {
- const auto & model = lctx.model;
- const auto & hparams = model.hparams;
- const auto & cparams = lctx.cparams;
-
- const auto & kv_self = lctx.kv_self;
-
- GGML_ASSERT(!!kv_self.ctx);
-
- const int64_t n_embd = hparams.n_embd;
- const int64_t n_layer = hparams.n_layer;
- const int64_t n_ctx = cparams.n_ctx;
- const int64_t n_head = hparams.n_head;
- const int64_t n_embd_head = hparams.n_embd_head();
- const int64_t n_embd_gqa = hparams.n_embd_gqa();
-
- const float norm_eps = hparams.f_norm_eps;
- const float clamp_kqv = hparams.f_clamp_kqv;
- const float max_alibi_bias = hparams.f_max_alibi_bias;
-
- const int32_t n_tokens = batch.n_tokens;
- const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
- const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
-
- auto & buf_compute = lctx.buf_compute;
-
- struct ggml_init_params params = {
- /*.mem_size =*/ buf_compute.size,
- /*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ false,
- };
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * attn_norm;
- params.no_alloc = true;
-
- struct ggml_context * ctx0 = ggml_init(params);
-
- ggml_cgraph * gf = ggml_new_graph(ctx0);
-
- struct ggml_tensor * cur;
- struct ggml_tensor * inpL;
+ attn_norm = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm,
+ NULL,
+ LLM_NORM, cb, il);
+ cb(attn_norm, "attn_norm", il);
- inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
- cb(inpL, "inp_embd", -1);
+ // self-attention
+ {
+ cur = attn_norm;
- // KQ_scale
- struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
- cb(KQ_scale, "KQ_scale", -1);
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+ cb(cur, "wqkv", il);
- // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
- struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
- cb(KQ_mask, "KQ_mask", -1);
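+ // optional clamping of the fused QKV activations (MPT's clamp_qkv setting, when non-zero)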
+ if (hparams.f_clamp_kqv > 0.0f) {
+ cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+ cb(cur, "wqkv_clamped", il);
+ }
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * attn_norm;
+ struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
- attn_norm = llm_build_norm(ctx0, inpL,
- model.layers[il].attn_norm,
- NULL,
- LLM_NORM, norm_eps, cb, il);
- cb(attn_norm, "attn_norm", il);
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
- // self-attention
- {
- cur = attn_norm;
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
- cb(cur, "wqkv", il);
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
- if (clamp_kqv > 0.0f) {
- cur = ggml_clamp(ctx0, cur, -clamp_kqv, clamp_kqv);
- cb(cur, "wqkv_clamped", il);
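+ // MPT uses ALiBi instead of positional embeddings; the max bias comes from hparams.f_max_alibi_bias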
+ cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
+ model.layers[il].wo, NULL,
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, cb, il);
+ cb(cur, "kqv_out", il);
}
- struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
- struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
- struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
- cb(Qcur, "Qcur", il);
- cb(Kcur, "Kcur", il);
- cb(Vcur, "Vcur", il);
+ // add the residual: attention output + block input
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+ cb(ffn_inp, "ffn_inp", il);
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ // feed forward
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm,
+ NULL,
+ LLM_NORM, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, NULL,
+ NULL, NULL,
+ model.layers[il].ffn_down, NULL,
+ LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+ cb(cur, "ffn_out", il);
+ }
- llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "l_out", il);
- cur = llm_build_kqv(lctx, ctx0, Qcur,
- model.layers[il].wo, NULL,
- Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, max_alibi_bias, cb, il);
- cb(cur, "kqv_out", il);
+ // input for next layer
+ inpL = cur;
}
- // Add the input
- struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
- cb(ffn_inp, "ffn_inp", il);
+ cur = inpL;
- // feed forward
- {
- cur = llm_build_norm(ctx0, ffn_inp,
- model.layers[il].ffn_norm,
- NULL,
- LLM_NORM, norm_eps, cb, il);
- cb(cur, "ffn_norm", il);
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm,
+ NULL,
+ LLM_NORM, cb, -1);
+ cb(cur, "result_norm", -1);
- cur = llm_build_ffn(ctx0, cur,
- model.layers[il].ffn_up, NULL,
- NULL, NULL,
- model.layers[il].ffn_down, NULL,
- LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
- cb(cur, "ffn_out", il);
- }
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
- cur = ggml_add(ctx0, cur, ffn_inp);
- cb(cur, "l_out", il);
+ ggml_build_forward_expand(gf, cur);
- // input for next layer
- inpL = cur;
+ return gf;
}
-
- cur = inpL;
-
- cur = llm_build_norm(ctx0, cur,
- model.output_norm,
- NULL,
- LLM_NORM, norm_eps, cb, -1);
- cb(cur, "result_norm", -1);
-
- cur = ggml_mul_mat(ctx0, model.output, cur);
- cb(cur, "result_output", -1);
-
- ggml_build_forward_expand(gf, cur);
-
- ggml_free(ctx0);
-
- return gf;
-}
+};
//
// tensor offloading helpers
@@ -5122,43 +4862,49 @@ static struct ggml_cgraph * llama_build_graph(
struct ggml_cgraph * result = NULL;
+ struct llm_build_context llm(lctx, batch, cb, worst_case);
+
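+ // the ggml context setup/teardown previously done inline in each builder now happens in llm.init()/llm.free()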
+ llm.init();
+
switch (model.arch) {
case LLM_ARCH_LLAMA:
{
- result = llm_build_llama(lctx, batch, cb, worst_case);
+ result = llm.build_llama();
} break;
case LLM_ARCH_BAICHUAN:
{
- result = llm_build_baichaun(lctx, batch, cb, worst_case);
+ result = llm.build_baichuan();
} break;
case LLM_ARCH_FALCON:
{
- result = llm_build_falcon(lctx, batch, cb, worst_case);
+ result = llm.build_falcon();
} break;
case LLM_ARCH_STARCODER:
{
- result = llm_build_starcoder(lctx, batch, cb, worst_case);
+ result = llm.build_starcoder();
} break;
case LLM_ARCH_PERSIMMON:
{
- result = llm_build_persimmon(lctx, batch, cb, worst_case);
+ result = llm.build_persimmon();
} break;
case LLM_ARCH_REFACT:
{
- result = llm_build_refact(lctx, batch, cb, worst_case);
+ result = llm.build_refact();
} break;
case LLM_ARCH_BLOOM:
{
- result = llm_build_bloom(lctx, batch, cb, worst_case);
+ result = llm.build_bloom();
} break;
case LLM_ARCH_MPT:
{
- result = llm_build_mpt(lctx, batch, cb, worst_case);
+ result = llm.build_mpt();
} break;
default:
GGML_ASSERT(false);
}
+ llm.free();
+
if (worst_case) {
int n_non_view_total = 0;