summaryrefslogtreecommitdiff
path: root/examples/train-text-from-scratch
diff options
context:
space:
mode:
Diffstat (limited to 'examples/train-text-from-scratch')
-rw-r--r--examples/train-text-from-scratch/train-text-from-scratch.cpp4
1 files changed, 2 insertions, 2 deletions
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 31d6620a..79b117df 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1868,10 +1868,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
- t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
+ t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
- t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
+ t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
t04->grad = expand(gb, ggml_add_inplace(ctx0,
ggml_add_inplace(ctx0,