author | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-19 18:51:41 +0300
---|---|---
committer | Iwan Kawrakow <iwan.kawrakow@gmail.com> | 2024-06-22 12:02:52 +0300
commit | d08ff0df433ee9dd8643afe1cf501c4154067cd2 |
tree | 577f5e609086ce5018b773256f5605487ded3d51 |
parent | ad60fb35677c6fffdc0b17ac61f1796f416a8e8f |
Revert "bitnet(scale in a separate tensor): replace ggml_mul with ggml_scale"
This reverts commit f83381371b61e0863b55c60e5f5df139126a496d.
When using CUDA, the tensor contents have not been loaded yet at the
point where the graph is built, so we crash when trying to read the
scale from the tensor's data. There must be a better way.
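For context, a minimal sketch of the failure mode (names mirror the diff below; this illustrates the mechanism, not the exact code path): `ggml_scale()` takes the scale as an immediate `float`, so the value has to be read from the scale tensor's `data` pointer while the graph is being built, whereas `ggml_mul()` records a multiply node whose second operand is only read at compute time.

```c
// Sketch only, reusing the tensor names from the diff below.

// Reverted variant: ggml_scale() wants an immediate float, so the graph
// builder dereferences wq_scale->data right here. With CUDA the tensor's
// contents are not available on the host at this point, so the read crashes.
Qcur = ggml_scale(ctx0, Qcur, *(const float *) model.layers[il].wq_scale->data);

// Restored variant: ggml_mul() just adds a node to the graph; the scale
// tensor is only dereferenced at compute time, by the backend that
// actually holds its data.
Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
```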
llama.cpp | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
@@ -11826,7 +11826,7 @@ struct llm_build_context {
                 { // compute Q and K and RoPE them
                     struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-                    Qcur = ggml_scale(ctx0, Qcur, *(const float *)model.layers[il].wq_scale->data);
+                    Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
                     cb(Qcur, "Qcur", il);
                     if (model.layers[il].bq) {
                         Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
@@ -11835,7 +11835,7 @@ struct llm_build_context {
                 // B1.K
                 struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
-                Kcur = ggml_scale(ctx0, Kcur, *(const float *)model.layers[il].wk_scale->data);
+                Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
@@ -11844,7 +11844,7 @@ struct llm_build_context {
                 // B1.V
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
-                Vcur = ggml_scale(ctx0, Vcur, *(const float *)model.layers[il].wv_scale->data);
+                Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -11938,7 +11938,7 @@ struct llm_build_context {
                 ggml_build_forward_expand(gf, cur_attn);
                 cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn);
-                cur = ggml_scale(ctx0, cur, *(const float *)model.layers[il].wo_scale->data);
+                cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
                 cb(cur, "kqv_out", il);
             }
@@ -11961,12 +11961,12 @@ struct llm_build_context {
             cb(cur, "ffn_norm", il);
             struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
-            tmp = ggml_scale(ctx0, tmp, *(const float *)model.layers[il].ffn_up_scale->data);
-
+            tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale);
+
             cb(tmp, "ffn_up", il);
             cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
-            cur = ggml_scale(ctx0, cur, *(const float *)model.layers[il].ffn_gate_scale->data);
+            cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale);
             cb(cur, "ffn_gate", il);
@@ -11983,7 +11983,7 @@ struct llm_build_context {
             cb(cur, "ffn_sub_norm", il);
             cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
-            cur = ggml_scale(ctx0, cur, *(const float *)model.layers[il].ffn_down_scale->data);
+            cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
             cb(cur, "ffn_down", il);
         }
         cur = ggml_add(ctx0, cur, ffn_inp);