path: root/examples/finetune/finetune.cpp
author    Georgi Gerganov <ggerganov@gmail.com>    2023-12-21 23:20:49 +0200
committer GitHub <noreply@github.com>              2023-12-21 23:20:49 +0200
commit    afefa319f1f59b002dfa0d1ef407a2c74bd9770b (patch)
tree      a6923e0a6214293d88957cd11e25943f2c0fb80a /examples/finetune/finetune.cpp
parent    769a7bc85eaa44e3d7eadf39abfeff7bb0b9cc2f (diff)
ggml : change ggml_scale to take a float instead of tensor (#4573)
* ggml : change ggml_scale to take a float instead of tensor
* ggml : fix CPU implementation
* tests : fix test-grad0

ggml-ci
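For context, a minimal before/after sketch of the signature change (assuming the ggml.h API from this commit onward; scale_scores and head_dim are illustrative names, not taken from the diff):

    #include <math.h>
    #include "ggml.h"

    // Scale attention scores by 1/sqrt(head_dim).
    static struct ggml_tensor * scale_scores(struct ggml_context * ctx, struct ggml_tensor * kq, int head_dim) {
        // old API (before this commit): the factor had to be wrapped in a 1-element tensor
        //   struct ggml_tensor * s = ggml_new_f32(ctx, 1.0f/sqrtf((float) head_dim));
        //   return ggml_scale_inplace(ctx, kq, s);
        // new API (this commit): pass the float directly
        return ggml_scale_inplace(ctx, kq, 1.0f/sqrtf((float) head_dim));
    }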
Diffstat (limited to 'examples/finetune/finetune.cpp')
-rw-r--r--  examples/finetune/finetune.cpp  |  42
1 files changed, 20 insertions, 22 deletions
diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index 6a668d76..7b1333a9 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -269,7 +269,7 @@ static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_h
float rope_freq_scale = 1.0f;
GGUF_GET_KEY(ctx, hparams->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
GGUF_GET_KEY(ctx, hparams->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
- GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+ GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
if (rope_freq_scale != 1.0f) {
hparams->rope_freq_scale = 1.0f / rope_freq_scale;
}
@@ -612,6 +612,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
const int n_rot = hparams.n_embd_head();
const int n_embd_head = hparams.n_embd_head();
const int n_embd_gqa = hparams.n_embd_gqa();
+
const float rms_norm_eps = hparams.f_norm_rms_eps;
const float rope_freq_base = hparams.rope_freq_base;
const float rope_freq_scale = hparams.rope_freq_scale;
@@ -680,10 +681,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
checkpoints.push_back(t01);
}
- struct ggml_tensor * kv_scale = NULL;
- if (!enable_flash_attn) {
- kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
- }
+ const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);
for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
@@ -781,32 +779,32 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// make sure some tensors are not reallocated by inserting new temporary nodes depending on them
int n_leafs_before = gb->n_leafs;
int n_nodes_before = gb->n_nodes;
- struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
+
// output tensors
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
// input gradient
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
ggml_allocr_alloc(alloc, t36->grad);
// KQ_pos
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
// make sure base model tensors data cannot be used in viewable operations
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, 1.0f));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, 1.0f));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, 1.0f));
for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, one));
- ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, 1.0f));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, 1.0f));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, 1.0f));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, 1.0f));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, 1.0f));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, 1.0f));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, 1.0f));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, 1.0f));
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, 1.0f));
}
// allocating checkpoints in one block to reduce memory fragmentation
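Beyond the signature change itself, the hunk above also shows the no-op scale trick this file relies on: building a scale-by-1.0f node that depends on a tensor so the graph allocator keeps that tensor's buffer in place. A hedged sketch of the pattern (pin_tensor is an illustrative helper name, not part of the diff):

    #include "ggml.h"

    // Add a dummy dependency on t so the allocator does not reallocate its data.
    static void pin_tensor(struct ggml_context * ctx, struct ggml_cgraph * gb, struct ggml_tensor * t) {
        // scaling by 1.0f is a no-op; with this commit the factor is a plain float
        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t, 1.0f));
    }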