summaryrefslogtreecommitdiff
path: root/examples/train-text-from-scratch/train-text-from-scratch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/train-text-from-scratch/train-text-from-scratch.cpp')
-rw-r--r--examples/train-text-from-scratch/train-text-from-scratch.cpp16
1 files changed, 13 insertions, 3 deletions
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 59c90c7b..5f541a14 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -679,15 +679,23 @@ struct ggml_tensor * llama_build_train_graphs(
}
};
+ // KQ_pos - contains the positions
+ struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
+ {
+ int * data = (int *) KQ_pos->data;
+ for (int i = 0; i < N; ++i) {
+ data[i] = n_past + i;
+ }
+ }
+
// rope has so much parameters that we make a custom function for it
- auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
+ auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
(struct ggml_tensor * t) -> struct ggml_tensor * {
// not capturing these, to silcence warnings
- const int n_past = 0;
const int rope_mode = 0;
return ggml_rope_custom(ctx,
- t, n_past, n_rot, rope_mode, n_ctx,
+ t, KQ_pos, n_rot, rope_mode, n_ctx,
rope_freq_base, rope_freq_scale);
};
@@ -787,6 +795,8 @@ struct ggml_tensor * llama_build_train_graphs(
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
// input gradient
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
+ // KQ_pos
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad));
ggml_allocr_alloc(alloc, t36->grad);
// gradient tensors (will be set to zero by ggml_graph_reset)