diff options
Diffstat (limited to 'examples/train-text-from-scratch/train-text-from-scratch.cpp')
-rw-r--r-- | examples/train-text-from-scratch/train-text-from-scratch.cpp | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 59c90c7b..5f541a14 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -679,15 +679,23 @@ struct ggml_tensor * llama_build_train_graphs( } }; + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // rope has so much parameters that we make a custom function for it - auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale] + auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale] (struct ggml_tensor * t) -> struct ggml_tensor * { // not capturing these, to silcence warnings - const int n_past = 0; const int rope_mode = 0; return ggml_rope_custom(ctx, - t, n_past, n_rot, rope_mode, n_ctx, + t, KQ_pos, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale); }; @@ -787,6 +795,8 @@ struct ggml_tensor * llama_build_train_graphs( ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one)); // input gradient ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one)); + // KQ_pos + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one)); GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad)); ggml_allocr_alloc(alloc, t36->grad); // gradient tensors (will be set to zero by ggml_graph_reset) |