author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 10:18:41 +0300
---|---|---
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:52 +0300
commit | 8c936e3d6593bec82975ba93bec05f9f03bb21f3 |
tree | 905607768a802ee341ab95a682a40529db913d92 |
parent | fc04994ebf8bfcb988a913cdd331bb120389bc44 |
bitnet: replace ggml_mul with ggml_scale to apply the scales
Also save one scale operation in the FFN by adjusting rms_eps.
We gain up to 3% in performance by doing this, but it is a bit
of a hack (we store the tensor scales in op_params while loading
the model).
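The rms_eps adjustment works because ggml_rms_norm computes x / sqrt(mean(x^2) + eps): if the input is missing a constant factor s (here, the skipped ffn_up scale), normalizing it with eps replaced by eps/s^2 gives exactly the same result as normalizing the correctly scaled input with the original eps. That is what the `1/(ffn_up_scale*ffn_up_scale)` argument passed to `llm_build_norm` in the diff below does. A minimal, standalone sketch of that identity in plain C++ (not part of the commit; the values of `s` and `x` are made up):

```cpp
// Numerical check of: rms_norm(s*x, eps) == rms_norm(x, eps/(s*s)),
// i.e. the ffn_up scale can be dropped and folded into rms_eps instead.
#include <cmath>
#include <cstdio>
#include <vector>

// RMS norm as ggml computes it (before the optional weight): y_i = x_i / sqrt(mean(x^2) + eps)
static std::vector<float> rms_norm(const std::vector<float> & x, float eps) {
    double sum = 0.0;
    for (float v : x) sum += (double)v * v;
    const float inv = 1.0f / sqrtf((float)(sum / x.size()) + eps);
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] * inv;
    return y;
}

int main() {
    const float eps = 1e-5f, s = 0.042f;   // s plays the role of ffn_up_scale
    std::vector<float> x = {0.3f, -1.2f, 2.5f, 0.7f};

    std::vector<float> scaled = x;
    for (float & v : scaled) v *= s;

    auto a = rms_norm(scaled, eps);        // old path: scale applied, normal eps
    auto b = rms_norm(x, eps / (s * s));   // new path: scale skipped, adjusted eps

    for (size_t i = 0; i < x.size(); ++i) {
        printf("%.7f  %.7f\n", a[i], b[i]); // the two columns agree
    }
    return 0;
}
```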
-rw-r--r-- | ggml.c    |  2
-rw-r--r-- | llama.cpp | 57
2 files changed, 44 insertions, 15 deletions
diff --git a/ggml.c b/ggml.c
--- a/ggml.c
+++ b/ggml.c
@@ -18939,7 +18939,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 //n_tasks = MIN(n_threads, ggml_nelements(node->src[1]));
                 n_tasks = MIN(n_cur_threads, ggml_nelements(node->src[1]));
             } break;
-        case GGML_OP_SCALE:
         case GGML_OP_SET:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -18963,6 +18962,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
             {
                 n_tasks = 1; //TODO
             } break;
+        case GGML_OP_SCALE:
         case GGML_OP_SOFT_MAX:
             {
                 n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
diff --git a/llama.cpp b/llama.cpp
--- a/llama.cpp
+++ b/llama.cpp
@@ -6884,6 +6884,27 @@ static bool llm_load_tensors(
         }
     }
 
+    if (model.arch == LLM_ARCH_BITNET) {
+        auto set_scale = [] (ggml_tensor * w, ggml_tensor * s) {
+            float scale = 1;
+            if (ggml_backend_buffer_is_host(s->buffer)) {
+                scale = *(const float *)s->data;
+            } else {
+                ggml_backend_tensor_get(s, &scale, 0, sizeof(float));
+            }
+            std::memcpy(w->op_params, &scale, sizeof(scale));
+        };
+        for (auto& l : model.layers) {
+            set_scale(l.ffn_up, l.ffn_up_scale);
+            set_scale(l.ffn_gate, l.ffn_gate_scale);
+            set_scale(l.ffn_down, l.ffn_down_scale);
+            set_scale(l.wq, l.wq_scale);
+            set_scale(l.wk, l.wk_scale);
+            set_scale(l.wv, l.wv_scale);
+            set_scale(l.wo, l.wo_scale);
+        }
+    }
+
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
     model.t_load_us = ggml_time_us() - model.t_start_us;
@@ -7061,10 +7082,10 @@ static struct ggml_tensor * llm_build_norm(
          struct ggml_tensor * mb,
               llm_norm_type   type,
          const llm_build_cb & cb,
-                        int   il) {
+                        int   il, float scale_eps = 1) {
     switch (type) {
         case LLM_NORM:     cur = ggml_norm    (ctx, cur, hparams.f_norm_eps);     break;
-        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps); break;
+        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, scale_eps * hparams.f_norm_rms_eps); break;
     }
 
     if (mw || mb) {
@@ -11822,13 +11843,14 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-#define BITNET_MUL ggml_mul
-
             // self-attention
             {
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-                Qcur = BITNET_MUL(ctx0, Qcur, model.layers[il].wq_scale);
+                float q_scale; std::memcpy(&q_scale, model.layers[il].wq->op_params, sizeof(float));
+                // Note: we could save this scale operation by applying the Q scale to K * Q further down
+                // (which also uses a scale). This works on the CPU and Metal backends, but produces NaNs on CUDA.
+                Qcur = ggml_scale(ctx0, Qcur, q_scale);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
@@ -11837,7 +11859,8 @@ struct llm_build_context {
 
                 // B1.K
                 struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
-                Kcur = BITNET_MUL(ctx0, Kcur, model.layers[il].wk_scale);
+                float k_scale; std::memcpy(&k_scale, model.layers[il].wk->op_params, sizeof(float));
+                Kcur = ggml_scale(ctx0, Kcur, k_scale);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
@@ -11846,7 +11869,8 @@ struct llm_build_context {
 
                 // B1.V
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
-                Vcur = BITNET_MUL(ctx0, Vcur, model.layers[il].wv_scale);
+                float v_scale; std::memcpy(&v_scale, model.layers[il].wv->op_params, sizeof(float));
+                Vcur = ggml_scale(ctx0, Vcur, v_scale);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -11878,6 +11902,8 @@ struct llm_build_context {
                 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
 
                 float kq_scale = 1.0f/sqrtf(float(n_embd_head));
+                // We would use this if we did not apply the Q scale above. Sadly, this fails on CUDA.
+                //float kq_scale = q_scale/sqrtf(float(n_embd_head));
                 struct ggml_tensor * cur_attn;
                 struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
                 cb(q, "q", il);
@@ -11940,7 +11966,8 @@ struct llm_build_context {
             ggml_build_forward_expand(gf, cur_attn);
 
             cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn);
-            cur = BITNET_MUL(ctx0, cur, model.layers[il].wo_scale);
+            float wo_scale; std::memcpy(&wo_scale, model.layers[il].wo->op_params, sizeof(float));
+            cur = ggml_scale(ctx0, cur, wo_scale);
             cb(cur, "kqv_out", il);
         }
 
@@ -11963,29 +11990,32 @@ struct llm_build_context {
             cb(cur, "ffn_norm", il);
 
             struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
-            tmp = BITNET_MUL(ctx0, tmp, model.layers[il].ffn_up_scale);
+            float ffn_up_scale; std::memcpy(&ffn_up_scale, model.layers[il].ffn_up->op_params, sizeof(float));
            cb(tmp, "ffn_up", il);
 
            cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
-            cur = BITNET_MUL(ctx0, cur, model.layers[il].ffn_gate_scale);
+            float ffn_gate_scale; std::memcpy(&ffn_gate_scale, model.layers[il].ffn_gate->op_params, sizeof(float));
+            cur = ggml_scale(ctx0, cur, ffn_gate_scale);
            cb(cur, "ffn_gate", il);
 
+            // combine this with the above scale into ggml_scaled_silu
            cur = ggml_silu(ctx0, cur);
            cb(cur, "ffn_silu", il);
 
-            cur = BITNET_MUL(ctx0, cur, tmp);
+            cur = ggml_mul(ctx0, cur, tmp);
            cb(cur, "ffn_gate_par", il);
 
            cur = llm_build_norm(ctx0, cur, hparams,
                            model.layers[il].ffn_sub_norm, NULL,
-                            LLM_NORM_RMS, cb, il);
+                            LLM_NORM_RMS, cb, il, 1/(ffn_up_scale*ffn_up_scale));
            cb(cur, "ffn_sub_norm", il);
 
            cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
-            cur = BITNET_MUL(ctx0, cur, model.layers[il].ffn_down_scale);
+            float ffn_down_scale; std::memcpy(&ffn_down_scale, model.layers[il].ffn_down->op_params, sizeof(float));
+            cur = ggml_scale(ctx0, cur, ffn_down_scale);
            cb(cur, "ffn_down", il);
         }
 
         cur = ggml_add(ctx0, cur, ffn_inp);
@@ -12009,7 +12039,6 @@ struct llm_build_context {
 
         ggml_build_forward_expand(gf, cur);
 
         return gf;
     }
-#undef BITNET_MUL
 };
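Why the scales end up in op_params: ggml_mul multiplied the activations by the one-element `*_scale` tensor at run time, while ggml_scale takes a plain float at graph-build time, so the value has to be available on the host when the graph is built. The commit therefore reads each scale out of its tensor once at load time and bit-copies it into the weight tensor's op_params field, which is normally reserved for operator parameters (hence "a bit of a hack"). A tiny standalone sketch of that round trip, using a stand-in struct rather than the real ggml_tensor (the array size mirrors ggml's 64-byte GGML_MAX_OP_PARAMS):

```cpp
// Illustrative only: a float is bit-copied into an int32 op_params array at
// load time (set_scale) and bit-copied back out at graph-build time.
#include <cstdint>
#include <cstdio>
#include <cstring>

struct fake_tensor {
    int32_t op_params[16]; // stand-in for ggml_tensor::op_params (64 bytes)
};

int main() {
    fake_tensor w = {};
    float scale_in = 0.0127f; // would come from the corresponding *_scale tensor

    // load time: stash the scale in the weight tensor's op_params
    std::memcpy(w.op_params, &scale_in, sizeof(scale_in));

    // graph-build time: recover it and pass it to ggml_scale
    float scale_out;
    std::memcpy(&scale_out, w.op_params, sizeof(scale_out));
    printf("%g\n", scale_out); // prints 0.0127
    return 0;
}
```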