diff options
Diffstat (limited to 'examples/train-text-from-scratch/train-text-from-scratch.cpp')
-rw-r--r-- | examples/train-text-from-scratch/train-text-from-scratch.cpp | 14 |
1 files changed, 5 insertions, 9 deletions
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index f7ed6336..4a9a2340 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -369,10 +369,7 @@ static struct ggml_tensor * llama_build_train_graphs( checkpoints.push_back(t00); checkpoints.push_back(t01); - struct ggml_tensor * kv_scale = NULL; - if (!enable_flash_attn) { - kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head)); - } + const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head); for (int il = 0; il < n_layer; ++il) { struct my_llama_layer & layer = model->layers[il]; @@ -444,14 +441,13 @@ static struct ggml_tensor * llama_build_train_graphs( // make sure some tensors are not reallocated by inserting new temporary nodes depending on them int n_leafs_before = gb->n_leafs; int n_nodes_before = gb->n_nodes; - struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f); // output tensors - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f)); // input gradient - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f)); // KQ_pos - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f)); GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL); ggml_allocr_alloc(alloc, t36->grad); |